ソースコードはそのまま流用していますが、line_profilerでの測定用に関数にしています。
x, t = get_data()
network = init_network()
def get_accuracy_with_batch(x, t, network):
batch_size = 100 #
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
x_batch = x[i:i+batch_size]
y_batch = predict(network, x_batch)
p = np.argmax(y_batch, axis=1)
accuracy_cnt += np.sum(p== t[i:i+batch_size])
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
p = np.argmax(y)
if p == t[i]:
accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
という関数にしておいて、バッチにしているものと、そのままのものを連続して動かします。%lprun -f get_accuracy_with_batch get_accuracy_with_batch(x, t, network)
で測定してみると。確かに高速になっています(1/6くらい)。
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def get_accuracy_with_batch(x, t, network):
4
5 1 2 2.0 0.0 batch_size = 100 #
6 1 1 1.0 0.0 accuracy_cnt = 0
7
8 101 115 1.1 0.0 for i in range(0, len(x), batch_size):
9 100 176 1.8 0.0 x_batch = x[i:i+batch_size]
10 100 58223 582.2 12.4 y_batch = predict(network, x_batch)
11 100 1111 11.1 0.2 p = np.argmax(y_batch, axis=1)
12 100 2029 20.3 0.4 accuracy_cnt += np.sum(p== t[i:i+batch_size])
13
14 1 116 116.0 0.0 print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
15
16 1 1 1.0 0.0 accuracy_cnt = 0
17 10001 4512 0.5 1.0 for i in range(len(x)):
18 10000 366764 36.7 77.8 y = predict(network, x[i])
19 10000 23556 2.4 5.0 p = np.argmax(y)
20
21 10000 9054 0.9 1.9 if p == t[i]:
22 9352 5339 0.6 1.1 accuracy_cnt += 1
23
24 1 171 171.0 0.0 print("Accuracy:" + str(float(accuracy_cnt) / len(x))