目次
概要
※前半Numbaと全然関係ないです。
「Pythonは動作が遅い」という話はよく聞きます。
そして、私自身そう思うときもあります。
しかし、機械学習タスクで圧倒的に使いやすいのがPythonであることも事実です。
今日に至るまでScikit-learnやNumpy等、
機械学習に用いられるライブラリは進化し続けてきました。
ここでは 「C++使えばいい」 という身も蓋もない話には耳をふさぎ、
Pythonでの様々なコードを高速化するNumbaについて深く考えます。
具体的には、なぜ早くなるのか?を考えます。
まずはコードの書き方を変える
タイトルの通りNumbaを扱う記事になりますが、
まずはコードの書き方を変える事を考えます。
参考:実行時間の計測
恐らくベンチマークを取るライブラリはいくつかあると思いますが、
今回はContextmanager
を使った方法を紹介しておきます。
from contextlib import contextmanager import time import logzero logger = logzero.setup_logger( name=f'{__name__}_logger', logfile=f'{__name__}.log', level=10, formatter=None, maxBytes=1000, backupCount=3, fileLoglevel=30, disableStderrLogger=False) @contextmanager def timer(phase): before = time.time() yield logger.debug(f'[{phase}] done in {time.time() - before:.3f}s')
簡単に解説です。
まず、yield
は渡されたブロックが入ります。
つまり、
Rubyでいえば、
def sum(x,y,&block) yield if block_given? return x+y end puts sum(30,50){puts "計算します..."}
計算します... 80
というような感じです。
渡されたブロックを呼び出していますよね。
Pythonでも似たような事が実現できます。
with timer('training model'): #処理が始まる前にbeforeとして現在時刻を保存 #Processing... #処理終わったらログが出力される
[D 190121 13:49:34 optimize:20] [example] done in 3.000s
かなり使いやすいのでおすすめです。
例えばファイルパスを渡して開き、最後に必ずストリームを閉じるような使い方もできそうです。
機械学習で良く用いられるのが配列( pd.DataFrame
)の計算です。
要素数が大きいと処理内容に時間が掛かりそうですが…。
import numpy as np a = np.array([x * np.random.rand(1)[0] for x in range(1,100001)]) b = np.array([x * np.random.rand(1)[0] for x in range(1,100001)])
のように、100kの要素数を持つnp.array
を定義してみます。
各要素全てを合計する。
a = [1,3,5,7,9] b = [2,4,6,8,10] print(sum(a)+sum(b)) #=>55
のような計算を、巨大なnp.array
に対して行ってみます。
組み込み関数sum()
を使う(速度:普通)
def sum2array(x:np.array,x:np.array) -> int: return int(sum(x)+sum(y)) with timer('Hugelist'): print(sum2array(a,b))
組み込み関数sum()
を愚直に使った計算時間を見てみます。
結果は0.2秒となりました。
思ったより早かったです。
np.sum()
を使う(速度:高速)
Numpyは大部分がC++で実装されているようなので、
np.sum()
も早いという噂を聞いたことがあります。
def sum2array(x:np.array,x:np.array) -> int: return int(np.sum(x)+np.sum(y)) with timer('Hugelist'): print(sum2array(a,b))
なんと20分の1まで下がりました。
配列の合計はnp.sum()
が最も早そうです。
フィボナッチ数列を求める。
今度はNumbaの解説記事で頻繁に用いられる、
フィボナッチ数列のn番目の数を求める関数を見ていきます。
フィボナッチ数列を使う、というのはあるあるですが、
今回の関数定義は自前で作りました。間違ってるかも…?
def fib(n): return fib(n-1)+fib(n-2) if n>2 else 1
のような関数を定義してみました。
非JIT(実行速度:遅い)
with timer('fib'): print(fib(40)) #=>102334155
大体20秒程度かかりました。
JIT(実行速度:爆速)
@jit def fib(n): return fib(n-1)+fib(n-2) if n>2 else 1
約33.3分の1になりました。
何故早くなるのか、後ほど見ていきます。
明示的型付けJIT(実行速度:更に爆速)
@jit('i8(i8)') def fib(n): return fib(n-1)+fib(n-2) if n>2 else 1
@jit(戻り値の型(引数...の型))
として型指定できます。
さらに0.15秒程度早くなりました。
NumbaとJITについて理解する
使い方は各自ドキュメントをご覧いただくとして、
ここでは Numba のユースケース、最適化される場合を紹介しておきます。
NumbaはNumpy配列や関数を用いたコード(ループ)を最適化する為のJITコンパイラです。
@jit
デコレータを関数定義時に挟む使い方が最も汎用的です。
@jit
を付随して定義した関数は,
JITコンパイラによって一部又は全てがバイトコードに変換されます。
詳しく解説します。
Pythonの公式実装(CPython)はC言語で出来ているので、
その組み込み関数やインタプリタもCで実装されています。
実際にのSumが実装されているであろうコードを載せてみます。
static PyObject* builtin_sum(PyObject *self, PyObject *args) { //省略 }
しかし、
def func(a,b): return a**b
このように定義された関数は、
Pythonのインタプリタによって一行ずつ評価されていくはずです。
つまり、関数呼び出し時には
- 変数
func
が定義されているか探索 - 見つかれば渡された引数を一時的にコピー
- 関数内の処理を一行ずつ実行し
- 値を返す
ということが行われています(リターンアドレス云々レベルはおいておいて抽象的にまとめました。)
つまり、組み込み関数はバイトコードを参照していますが、
新しく定義した関数はPythonの処理系がPythonの関数定義を参照しているという形になると思います。
@jit def func(a,b): return a**b
として定義した関数は、JITコンパイラによってバイトコードに変換されます。
つまり、呼び出し回数が多い関数はNumbaによるバイトコード化が有効となりそうです。
呼び出し回数はほぼ指数関数的に増加するので、
フィボナッチ数列に渡すnの大きさと、
それに対する実行時間でグラフを描画してみましょう。
import time import numpy as np from numba import jit import matplotlib.pyplot as plt import seaborn as sns import pandas as pd sns.set() @jit def usejit(n): return usejit(n-1)+usejit(n-2) if n>2 else 1 def nonusejit(n): return nonusejit(n-1)+nonusejit(n-2) if n>2 else 1 nonuse = [] use = [] nonuse_df = pd.DataFrame() use_df = pd.DataFrame() nlist = [x for x in range(1,41)] for n in nlist: before = time.time() nonusejit(n) nonuse.append(round(time.time() - before,5)) before = time.time() usejit(n) use.append(round(time.time() - before,5)) nonuse = np.array(nonuse) use = np.array(use) nonuse_df['Proctime'] = nonuse nonuse_df['num'] = nlist use_df['Proctime'] = use use_df['num'] = nlist sns.relplot(x='Proctime',y='num',data=nonuse_df,kind='line') plt.savefig('nonuse.png') sns.relplot(x='Proctime',y='num',data=use_df,kind='line') plt.savefig('use.png')
非JIT
35ぐらいからどんどん増えてってますね。
この増え幅に注目しましょう。
関数の呼び出し回数に対してメキメキ増えています。
JIT
ちょっと一回謎のブレがありますが、
大きくなっていっても小数点第一位以下の伸びだとわかります。
関数呼び出し回数は変わっていないですが、
バイトコードに変換されているので爆速になっていますね。
参考:Pythonの関数定義
Pythonは実際に関数が呼び出されるまで中のコードを検査しない(シンタックスエラーは除く)ので、
def func(a,b): fsjkadflsdjafkldsjaklfjasd
とかいう謎コードを書いても、
実際に関数が呼び出されるまでfsjkadflsdjafkldsjaklfjasdの内容を検査しません。
そういう変数が定義されていれば呼び出してもその値が評価されます。
インポートしていないライブラリを利用する関数を定義しても実際に呼び出すまでエラーを吐かないのはそれです。
総評
最近よく使われるNumbaについて自分でいろいろ考えてみて、なんとなく納得できました。