生活随笔
收集整理的這篇文章主要介紹了
RNN的手写数字识别
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
RNN和LSTM的原理可以看這篇文章
以下是RNN在手寫數字識別上的簡單應用
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn#重置graph,不然會出錯
tf.reset_default_graph()# 導入數據
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)# 輸入圖片28*28
n_inputs = 28 # 輸入一行,28個像素點
max_time = 28 # 共28行
lstm_size = 100 # 隱層單元
n_classes = 10
batch_size = 50
n_batch = mnist.train.num_examples // batch_size# None表示第一個維度可以是任意的長度
# 定義兩個placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])# 初始化權值
weights = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
# 初始化偏置值
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))# 定義RNN網絡
def RNN(X, weights, biases):# inputs=[batbatch_size,max_time,n_input]inputs = tf.reshape(X, [-1, max_time, n_inputs])# 定義LSTM基本celllstm_cell = rnn.BasicLSTMCell(lstm_size)# 1.0版本改了很多# 原代碼是這樣的:# lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)# 應該改為:# from tensorflow.contrib import rnn# lstm_cell = rnn.BasicLSTMCell(lstm_size)# final_state[0]是cell_state# final_state[1]是hidden_stateoutputs, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)results = tf.nn.softmax(tf.matmul(final_state[1], weights) + biases)return results# 計算RNN的返回結果
prediction = RNN(x, weights, biases)
# 損失函數
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
# 使用Adam
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# 結果存放在一個布爾列表中
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
# 求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(6):for batch in range(n_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})print('iter' + str(epoch) + ',testing accuracy=' + str(acc))
輸出結果是這樣:
iter0,testing accuracy=0.7495
iter1,testing accuracy=0.841
iter2,testing accuracy=0.89
iter3,testing accuracy=0.9112
iter4,testing accuracy=0.9081
iter5,testing accuracy=0.9211
總結
以上是生活随笔為你收集整理的RNN的手写数字识别的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。