Explore "Full-Stack" in depth!

情報系の専門学校で、今は機械学習に的を絞って学習中。プログラミングを趣味でやりつつ、IT系のあらゆる知識と技術を身に付けるべく奮闘中。

最近人気なNumbaを深く考察する。

目次

概要

※前半Numbaと全然関係ないです。

Pythonは動作が遅い」という話はよく聞きます。
そして、私自身そう思うときもあります

しかし、機械学習タスクで圧倒的に使いやすいのがPythonであることも事実です。
今日に至るまでScikit-learnNumpy等、
機械学習に用いられるライブラリは進化し続けてきました。

ここでは C++使えばいい」 という身も蓋もない話には耳をふさぎ、
Pythonでの様々なコードを高速化するNumbaについて深く考えます。

具体的には、なぜ早くなるのか?を考えます。


まずはコードの書き方を変える

タイトルの通りNumbaを扱う記事になりますが、
まずはコードの書き方を変える事を考えます。

参考:実行時間の計測

恐らくベンチマークを取るライブラリはいくつかあると思いますが、
今回はContextmanagerを使った方法を紹介しておきます。

from contextlib import contextmanager
import time
import logzero

logger = logzero.setup_logger(
        name=f'{__name__}_logger',
        logfile=f'{__name__}.log',
        level=10,
        formatter=None,
        maxBytes=1000,
        backupCount=3,
        fileLoglevel=30,
        disableStderrLogger=False)

@contextmanager
def timer(phase):
    before = time.time()
    yield
    logger.debug(f'[{phase}] done in {time.time() - before:.3f}s')

簡単に解説です。

まず、yieldは渡されたブロックが入ります。

つまり、

Rubyでいえば、

def sum(x,y,&block)
  yield if block_given?
  return x+y
end

puts sum(30,50){puts "計算します..."}
計算します...
80

というような感じです。
渡されたブロックを呼び出していますよね。

Pythonでも似たような事が実現できます。

with timer('training model'):
  #処理が始まる前にbeforeとして現在時刻を保存
  #Processing...
  #処理終わったらログが出力される
[D 190121 13:49:34 optimize:20] [example] done in 3.000s

かなり使いやすいのでおすすめです。

例えばファイルパスを渡して開き、最後に必ずストリームを閉じるような使い方もできそうです。


機械学習で良く用いられるのが配列( pd.DataFrame )の計算です。

素数が大きいと処理内容に時間が掛かりそうですが…。

import numpy as np
a = np.array([x * np.random.rand(1)[0] for x in range(1,100001)])
b = np.array([x * np.random.rand(1)[0] for x in range(1,100001)])

のように、100kの要素数を持つnp.arrayを定義してみます。

各要素全てを合計する。

a = [1,3,5,7,9]
b = [2,4,6,8,10]
print(sum(a)+sum(b)) #=>55

のような計算を、巨大なnp.arrayに対して行ってみます。

組み込み関数sum()を使う(速度:普通)

def sum2array(x:np.array,x:np.array) -> int:
  return int(sum(x)+sum(y))

with timer('Hugelist'):
  print(sum2array(a,b))

組み込み関数sum()を愚直に使った計算時間を見てみます。

f:id:orangebladdy:20190121175830j:plain

結果は0.2秒となりました。

思ったより早かったです。

np.sum()を使う(速度:高速)

Numpyは大部分がC++で実装されているようなので、
np.sum()も早いという噂を聞いたことがあります。

def sum2array(x:np.array,x:np.array) -> int:
  return int(np.sum(x)+np.sum(y))

with timer('Hugelist'):
  print(sum2array(a,b))

f:id:orangebladdy:20190121175851j:plain

なんと20分の1まで下がりました。

配列の合計はnp.sum()が最も早そうです。


フィボナッチ数列を求める。

今度はNumbaの解説記事で頻繁に用いられる、
フィボナッチ数列n番目の数を求める関数を見ていきます。

フィボナッチ数列を使う、というのはあるあるですが、
今回の関数定義は自前で作りました。間違ってるかも…?

def fib(n):
    return fib(n-1)+fib(n-2) if n>2 else 1

のような関数を定義してみました。

JIT(実行速度:遅い)

with timer('fib'):
  print(fib(40)) #=>102334155

f:id:orangebladdy:20190121175915j:plain

大体20秒程度かかりました。

JIT(実行速度:爆速)

@jit
def fib(n):
  return fib(n-1)+fib(n-2) if n>2 else 1

f:id:orangebladdy:20190121175956j:plain

33.3分の1になりました。
何故早くなるのか、後ほど見ていきます。

明示的型付けJIT(実行速度:更に爆速)

@jit('i8(i8)')
def fib(n):
  return fib(n-1)+fib(n-2) if n>2 else 1

f:id:orangebladdy:20190121180059j:plain

@jit(戻り値の型(引数...の型))として型指定できます。

さらに0.15秒程度早くなりました。

NumbaとJITについて理解する

使い方は各自ドキュメントをご覧いただくとして、
ここでは Numbaユースケース、最適化される場合を紹介しておきます。

NumbaはNumpy配列や関数を用いたコード(ループ)を最適化する為のJITコンパイラです。
@jitデコレータを関数定義時に挟む使い方が最も汎用的です。
@jitを付随して定義した関数は,
JITコンパイラによって一部又は全てがバイトコードに変換されます。

詳しく解説します。

Pythonの公式実装(CPython)はC言語で出来ているので、
その組み込み関数やインタプリタもCで実装されています。

実際にのSumが実装されているであろうコードを載せてみます。

static PyObject*
builtin_sum(PyObject *self, PyObject *args)
{
  //省略
}

これらはCコンパイラによって事前コンパイルされます。

しかし、

def func(a,b):
  return a**b

このように定義された関数は、
Pythonインタプリタによって一行ずつ評価されていくはずです。

つまり、関数呼び出し時には

  • 変数funcが定義されているか探索
  • 見つかれば渡された引数を一時的にコピー
  • 関数内の処理を一行ずつ実行し
  • 値を返す

ということが行われています(リターンアドレス云々レベルはおいておいて抽象的にまとめました。)

つまり、組み込み関数はバイトコードを参照していますが、
新しく定義した関数はPythonの処理系がPythonの関数定義を参照しているという形になると思います。

@jit
def func(a,b):
  return a**b

として定義した関数は、JITコンパイラによってバイトコードに変換されます。

つまり、呼び出し回数が多い関数はNumbaによるバイトコード化が有効となりそうです。

呼び出し回数はほぼ指数関数的に増加するので、

フィボナッチ数列に渡すnの大きさと、
それに対する実行時間でグラフを描画してみましょう。

import time
import numpy as np
from numba import jit
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
sns.set()

@jit
def usejit(n):
  return usejit(n-1)+usejit(n-2) if n>2 else 1

def nonusejit(n):
  return nonusejit(n-1)+nonusejit(n-2) if n>2 else 1

nonuse = []
use = []

nonuse_df = pd.DataFrame()
use_df = pd.DataFrame()
nlist = [x for x in range(1,41)]
for n in nlist:
    before = time.time()
    nonusejit(n)
    nonuse.append(round(time.time() - before,5))
    before = time.time()
    usejit(n)
    use.append(round(time.time() - before,5))
nonuse = np.array(nonuse)
use = np.array(use)
nonuse_df['Proctime'] = nonuse
nonuse_df['num'] = nlist
use_df['Proctime'] = use
use_df['num'] = nlist

sns.relplot(x='Proctime',y='num',data=nonuse_df,kind='line')
plt.savefig('nonuse.png')
sns.relplot(x='Proctime',y='num',data=use_df,kind='line')
plt.savefig('use.png')

JIT

f:id:orangebladdy:20190121180256p:plain

35ぐらいからどんどん増えてってますね。
この増え幅に注目しましょう。
関数の呼び出し回数に対してメキメキ増えています。

JIT

f:id:orangebladdy:20190121180532p:plain

ちょっと一回謎のブレがありますが、
大きくなっていっても小数点第一位以下の伸びだとわかります。

関数呼び出し回数は変わっていないですが、
バイトコードに変換されているので爆速になっていますね。

参考:Pythonの関数定義

Python実際に関数が呼び出されるまで中のコードを検査しない(シンタックスエラーは除く)ので、

def func(a,b):
  fsjkadflsdjafkldsjaklfjasd

とかいう謎コードを書いても、
実際に関数が呼び出されるまでfsjkadflsdjafkldsjaklfjasdの内容を検査しません。

そういう変数が定義されていれば呼び出してもその値が評価されます。
インポートしていないライブラリを利用する関数を定義しても実際に呼び出すまでエラーを吐かないのはそれです。


総評

最近よく使われるNumbaについて自分でいろいろ考えてみて、なんとなく納得できました。