2017年6月28日水曜日

ニューラルネットワークにおけるバッチ処理の高速化測定メモ

オライリー本で、ニューラルネットワークがバッチ処理で高速になる、という話があり(ディープラーニング/斎藤, 3.6.3章 バッチ処理)、どれほど速くなるのかを見てみました。

ソースコードはそのまま流用していますが、line_profilerでの測定用に関数にしています。
x, t = get_data()
network = init_network()
def get_accuracy_with_batch(x, t, network):

    batch_size = 100 #
    accuracy_cnt = 0
    for i in range(0, len(x), batch_size):
        x_batch = x[i:i+batch_size]
        y_batch = predict(network, x_batch)
        p = np.argmax(y_batch, axis=1)
        accuracy_cnt += np.sum(p== t[i:i+batch_size])
    print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

    accuracy_cnt = 0
    for i in range(len(x)):
        y = predict(network, x[i])
        p = np.argmax(y)
        if p == t[i]:
            accuracy_cnt += 1
    print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
という関数にしておいて、バッチにしているものと、そのままのものを連続して動かします。


%lprun -f get_accuracy_with_batch get_accuracy_with_batch(x, t, network)

で測定してみると。確かに高速になっています(1/6くらい)。
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def get_accuracy_with_batch(x, t, network):
     4                                             
     5         1            2      2.0      0.0      batch_size = 100 #
     6         1            1      1.0      0.0      accuracy_cnt = 0
     7                                         
     8       101          115      1.1      0.0      for i in range(0, len(x), batch_size):
     9       100          176      1.8      0.0          x_batch = x[i:i+batch_size]
    10       100        58223    582.2     12.4          y_batch = predict(network, x_batch)
    11       100         1111     11.1      0.2          p = np.argmax(y_batch, axis=1)
    12       100         2029     20.3      0.4          accuracy_cnt += np.sum(p== t[i:i+batch_size])
    13                                         
    14         1          116    116.0      0.0      print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
    15                                             
    16         1            1      1.0      0.0      accuracy_cnt = 0
    17     10001         4512      0.5      1.0      for i in range(len(x)):
    18     10000       366764     36.7     77.8          y = predict(network, x[i])
    19     10000        23556      2.4      5.0          p = np.argmax(y)
    20                                         
    21     10000         9054      0.9      1.9          if p == t[i]:
    22      9352         5339      0.6      1.1              accuracy_cnt += 1
    23                                         
    24         1          171    171.0      0.0      print("Accuracy:" + str(float(accuracy_cnt) / len(x))

0 件のコメント:

コメントを投稿