TensorFlow - Trying the MNIST tutorial (for experts)
This is a follow-up to my earlier post, "TensorFlow - Trying the MNIST tutorial (for beginners)" on ゲームAI備忘録.
This time I tried the MNIST tutorial for experts:
https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html
It builds a convolutional neural network (CNN) in TensorFlow and uses it to classify the MNIST handwritten digits.
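The image sizes noted in the code comments below (28x28 -> 14x14 -> 7x7) and the 7*7*64 input size of the first fully connected layer follow from TensorFlow's 'SAME' padding rule, under which each spatial dimension becomes ceil(size / stride). A minimal pure-Python sketch that traces the shapes through both CONV -> RELU -> POOL stages; the helper names are hypothetical and no TensorFlow is needed:

```python
def same_out(size, stride):
    """Output length along one axis with 'SAME' padding: ceil(size / stride)."""
    return -(-size // stride)  # integer ceiling division


def trace_shapes(height=28, width=28, channels=(1, 32, 64)):
    """Return (height, width, channels) after each conv and pool step."""
    shapes = [(height, width, channels[0])]
    for out_ch in channels[1:]:
        # 3x3 conv, stride 1, 'SAME' padding: spatial size unchanged, channels change
        height, width = same_out(height, 1), same_out(width, 1)
        shapes.append((height, width, out_ch))
        # 2x2 max pool, stride 2, 'SAME' padding: spatial size halved (rounded up)
        height, width = same_out(height, 2), same_out(width, 2)
        shapes.append((height, width, out_ch))
    return shapes


shapes = trace_shapes()
for s in shapes:
    print(s)
h, w, c = shapes[-1]
print("flattened features:", h * w * c)  # input size of the first FC layer
```

The last line prints 3136, i.e. 7*7*64, which is why the code reshapes h2 to [-1, 7*7*64] before the densely connected layer.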
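With the parameters below, about 5 minutes of training brought test accuracy to 96%, compared with 91% for the beginner tutorial. The run logged below took 7m40s of wall-clock time and reached 95.1%.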
Training for longer apparently pushes accuracy as high as 99.7%.
Parameters
Running the code
GitHub - namakemono/mnist-tensorflow
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# mnist_expert.py
import os
import tensorflow as tf

# Download the MNIST helper script from the TensorFlow 0.6.0 repository if it is missing.
if not os.path.exists("input_data.py"):
    os.system("curl https://raw.githubusercontent.com/tensorflow/tensorflow/0.6.0/tensorflow/examples/tutorials/mnist/input_data.py -o input_data.py")
import input_data

def conv2d(x, W):
    # Stride 1 with 'SAME' zero padding: output height/width equal the input's.
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

def max_pool_2x2(x):
    # 2x2 max pooling with stride 2: halves height and width.
    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

def weight_variable(shape):
    # Small random initial weights to break symmetry.
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    # Slightly positive initial bias so ReLU units start out active.
    return tf.Variable(tf.constant(0.1, shape=shape))

def main(args):
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.Session() as sess:
        # Variables
        x = tf.placeholder("float", shape=[None, 784])
        y_ = tf.placeholder("float", shape=[None, 10]) # y'
        keep_prob = tf.placeholder("float") # Used for Dropout: probability of keeping a unit

        # Build a Multilayer Convolutional Network: INPUT -> [CONV -> RELU -> POOL] * 2 -> FC -> RELU -> FC
        W1, b1 = weight_variable([3,3,1,32]), bias_variable([32]) # 3x3 Filter, input channel: 1, output channel: 32
        W2, b2 = weight_variable([3,3,32,64]), bias_variable([64]) # 3x3 Filter, input channel: 32, output channel: 64
        W3, b3 = weight_variable([7*7*64,1024]), bias_variable([1024])
        W4, b4 = weight_variable([1024,10]), bias_variable([10])

        x_ = tf.reshape(x, [-1, 28, 28, 1]) # 28x28, channel=1
        h1 = max_pool_2x2(tf.nn.relu(conv2d(x_, W1) + b1)) # First Convolutional Layer: CONV -> RELU -> POOL, image size: 28x28 -> 14x14
        h2 = max_pool_2x2(tf.nn.relu(conv2d(h1, W2) + b2)) # Second Convolutional Layer: CONV -> RELU -> POOL, image size: 14x14 -> 7x7
        h3 = tf.nn.relu(tf.matmul(tf.reshape(h2, [-1, 7*7*64]), W3) + b3) # Densely Connected Layer: FC -> RELU
        h4 = tf.nn.dropout(h3, keep_prob) # Dropout
        y = tf.nn.softmax(tf.matmul(h4, W4) + b4) # Readout Layer: FC

        # Train and Evaluate the Model
        # Categorical cross-entropy summed over the batch: H = -Σ y' * log(y)
        # (tf.log(y) becomes -inf if a softmax output underflows to 0)
        cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
        optimizer = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
        # Build the evaluation ops once, outside the training loop,
        # so the graph does not grow on every iteration.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        sess.run(tf.initialize_all_variables())
        for i in range(1000):
            images, labels = mnist.train.next_batch(50)
            optimizer.run(feed_dict={x: images, y_: labels, keep_prob: 0.5})
            if i % 100 == 0:
                # train-accuracy: current mini-batch of 50 images; test-accuracy: full test set.
                # keep_prob=1.0 disables dropout during evaluation.
                print("[%d]\ttrain-accuracy:%.5f\ttest-accuracy:%.5f" % (i,
                    accuracy.eval(feed_dict={x: images, y_: labels, keep_prob: 1.0}),
                    accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})))
        print("Accuracy: %.3f" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

if __name__ == "__main__":
    tf.app.run()
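The loss in the code applies tf.log to softmax probabilities. If a probability underflows to exactly 0, the log becomes -inf, and the loss turns into inf, or nan because 0 * -inf = nan. The pure-Python sketch below compares that naive form with a log-sum-exp version that works directly on the logits. The function names are hypothetical, and the loss is per example rather than summed over a batch:

```python
import math


def safe_log(p):
    # Mimic IEEE / tf.log behaviour: log(0) = -inf (math.log would raise ValueError).
    return math.log(p) if p > 0 else float("-inf")


def naive_cross_entropy(logits, one_hot):
    """H = -sum(y' * log(softmax(z))), computed from explicit probabilities."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(t * safe_log(p) for t, p in zip(one_hot, probs))


def stable_cross_entropy(logits, one_hot):
    """Same quantity via log-softmax: log p_k = z_k - logsumexp(z)."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum) for t, z in zip(one_hot, logits))


print(naive_cross_entropy([2.0, 1.0, 0.1], [1, 0, 0]))   # ordinary case: both agree
print(stable_cross_entropy([2.0, 1.0, 0.1], [1, 0, 0]))
print(naive_cross_entropy([0.0, -800.0], [1, 0]))        # exp(-800) underflows -> nan
print(stable_cross_entropy([0.0, -800.0], [1, 0]))       # finite
```

In TensorFlow itself, the usual remedy is to pass the pre-softmax logits (tf.matmul(h4, W4) + b4) to tf.nn.softmax_cross_entropy_with_logits rather than taking tf.log of the softmax output. Check that function's argument order for your TensorFlow version.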
Results
$ time python mnist_expert.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 4
can't determine number of CPU cores: assuming 4
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 4
[0] test-accuracy:0.10000 test-accuracy:0.10110
[100] test-accuracy:0.80000 test-accuracy:0.77120
[200] test-accuracy:0.94000 test-accuracy:0.87740
[300] test-accuracy:0.86000 test-accuracy:0.89560
[400] test-accuracy:0.98000 test-accuracy:0.91680
[500] test-accuracy:0.90000 test-accuracy:0.92050
[600] test-accuracy:0.96000 test-accuracy:0.93260
[700] test-accuracy:0.90000 test-accuracy:0.93360
[800] test-accuracy:0.88000 test-accuracy:0.94250
[900] test-accuracy:1.00000 test-accuracy:0.94650
Accuracy: 0.951

real 7m39.909s
user 12m24.497s
sys 2m26.872s
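Two quick checks on this log, as a plain-Python sketch. First, how many passes over the training data the run makes. This assumes read_data_sets keeps its default split of 55,000 training images (5,000 held out for validation); for comparison, the official tutorial trains for 20,000 steps. Second, the first accuracy column only moves in steps of 1/50 = 0.02. That is consistent with it being measured on a single mini-batch of 50 images, while the second column uses the full test set. The constant and function names are hypothetical:

```python
TRAIN_IMAGES = 55000  # assumed default training split of input_data.read_data_sets
BATCH_SIZE = 50       # mnist.train.next_batch(50) in the script


def epochs_covered(steps, batch_size=BATCH_SIZE, train_images=TRAIN_IMAGES):
    """Number of passes over the training set after `steps` mini-batch updates."""
    return steps * batch_size / float(train_images)


for steps in (1000, 20000):
    print("%5d steps -> %.2f epochs" % (steps, epochs_covered(steps)))

# First accuracy column of the log above (one value every 100 steps).
first_column = [0.10, 0.80, 0.94, 0.86, 0.98, 0.90, 0.96, 0.90, 0.88, 1.00]
# Accuracy on one mini-batch can only be k / BATCH_SIZE for an integer k.
on_batch_grid = all(abs(a * BATCH_SIZE - round(a * BATCH_SIZE)) < 1e-9
                    for a in first_column)
print("all values are multiples of 1/%d: %s" % (BATCH_SIZE, on_batch_grid))
```

Under these assumptions the 1000-step run sees each training image less than once (about 0.91 epochs). That likely explains much of the gap between its 95.1% and the higher accuracy reachable with longer training.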