tensorflow实现CNN识别手写数字
上一篇使用TensorFlow識別手寫數字,是直接采用了softmax進行多分類,直接將28*28的圖片轉換成為了784維的向量作為輸入,然后通過一個(784,10)的權重,將輸入轉換成一個10維的向量,最后再將對每一個數字預測一個概率,概率最大的數字就是預測的結果。因為,直接將圖片轉成一個784維的向量,丟棄了圖片原有的結構信息,但是最后對于測試集準確率還是可以達到91%。這一篇,介紹通過CNN來實現手寫數字的識別,準確率可以達到98%。
一、CNN(卷積神經網絡? ?convolutional neural network)
今年,可以說是人工智能被炒的最火的一年了。以致于python的使用快趕上java,人工智能能夠這么火,當然也離不開CNN。CNN的使用范圍也很廣在語音識別、自然語言處理、圖像處理都能看見它的身影。剛剛開始聽見卷積的時候,給我一種深不可測的感覺。在網上也看了很多關于卷積的文章,對于卷積的介紹也是非常詳細的,這里我推薦三篇文章對于卷積的介紹寫的還是非常不錯的(PS是英文的)
下面,簡單的畫一個圖,如何通過卷積來達到目標。當,我們在訓練一個卷積神經網絡的時候,需要大量的數據、設置卷積的層數、卷積核大小、池化的方式(最大、平均)、損失函數、設置目標(如:手寫數字的識別,輸出一個10維向量),然后讓卷積神經網絡通過不斷的訓練更新參數向我們設置的目標靠近,最后我們可以通過這些參數來預測樣本。至于,卷積的工作方式看上面的推薦的文章,有非常詳細的介紹。可能你會覺得這個東西有點抽象,網上有大牛將卷積的過程可視化,具體的可以參考它的git項目
https://github.com/yosinski/deep-visualization-toolbox
二、卷積神經網絡的結構圖
上面的流程圖,是整個卷積網絡的一個結構圖。我們在使用TensorFlow實現這個結構的時候,其實還是非常簡單的,只需要設置卷積核的大小,這里設計的是5*5,邊距的填充方式,卷積的個數、激活函數、池化的方式、輸出類別的個數,在最后我會給出使用TensorFlow實現整個結構的代碼。下面我會對這個結構提幾個問題并解答:
1、為什么卷積核的大小要設置成5*5,需要32個卷積?
卷積核的大小其實你可以自己隨便設置,如:3*3、5*5、7*7、9*9等,一般都為奇數,卷積的個數也是自己設置的,32個卷積的意思,代表的是你要提取原圖上32個特征(每一個卷積提取一種特征)。
2、為什么要使用RELU激活函數?
設置激活函數的目的是保證結果輸出的非線性化,RELU激活函數需要大于一個閾值,才會有輸出,和人的神經元結構很像,激活函數的種類有很多,RELU的變種就有很多,在卷積神經網絡中經常使用的激活函數有RELU和tanh。
3、為什么通過一個5*5的卷積和池化之后,原圖28*28的圖像就變成了14*14?
輸入28*28的圖像通過5*5的卷積之后,輸出還是28*28,這和卷積的方式有關,設置步長為1,如果對28*28的圖像設置不填充邊距,那么輸出圖像的大小應該是(28-5)/1 + 1,輸出圖像應該是24*24,如果我們將原圖的填充邊距設置為2(在原圖的周圍填充兩圈全0),來保證輸入圖像和輸出圖像的大小一致,這個時候的計算公式(28-5+2*2)/1 + 1,輸出圖像的大小還是和原圖保持一致。這樣做的目的,是為了防止輸入圖像經過卷積之后過快的衰減,因為有時候我們設計的卷積網絡層數可能達到上100層,而填充0并不會對結果有影響。常見池化的方式有兩種,均值和最大值,池化核的大小設置為2*2,代表是從2*2中選出一個值(平均值或者最大值),所以一個28*28的圖像再經過2*2的池化之后就變成了14*14,池化的目的是為了減少參數而且還可以很好的保證圖像的特征。
三、實現代碼
from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf#初始化權重函數 def weight_variable(shape):initial = tf.truncated_normal(shape,stddev=0.1);return tf.Variable(initial)#初始化偏置項 def bias_variable(shape):initial = tf.constant(0.1,shape=shape)return tf.Variable(initial)#定義卷積函數 def conv2d(x,w):return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')#定義一個2*2的最大池化層 def max_pool_2_2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')if __name__ == "__main__":#定義輸入變量x = tf.placeholder("float",shape=[None,784])#定義輸出變量y_ = tf.placeholder("float",shape=[None,10])#初始化權重,第一層卷積,32的意思代表的是輸出32個通道# 其實,也就是設置32個卷積,每一個卷積都會對圖像進行卷積操作w_conv1 = weight_variable([5,5,1,32])#初始化偏置項b_conv1 = bias_variable([32])#將輸入的x轉成一個4D向量,第2、3維對應圖片的寬高,最后一維代表圖片的顏色通道數# 輸入的圖像為灰度圖,所以通道數為1,如果是RGB圖,通道數為3# tf.reshape(x,[-1,28,28,1])的意思是將x自動轉換成28*28*1的數組# -1的意思是代表不知道x的shape,它會按照后面的設置進行轉換x_image = tf.reshape(x,[-1,28,28,1])# 卷積并激活h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1) + b_conv1)#池化h_pool1 = max_pool_2_2(h_conv1)#第二層卷積#初始權重w_conv2 = weight_variable([5,5,32,64])#初始化偏置項b_conv2 = bias_variable([64])#將第一層卷積池化后的結果作為第二層卷積的輸入h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2) + b_conv2)#池化h_pool2 = max_pool_2_2(h_conv2)# 設置全連接層的權重w_fc1 = weight_variable([7*7*64,1024])# 設置全連接層的偏置b_fc1 = bias_variable([1024])# 將第二層卷積池化后的結果,轉成一個7*7*64的數組h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])# 通過全連接之后并激活h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1) + b_fc1)# 防止過擬合keep_prob = tf.placeholder("float")h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)#輸出層w_fc2 = weight_variable([1024,10])b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2)#日志輸出,每迭代100次輸出一次日志#定義交叉熵為損失函數cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))#最小化交叉熵train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)#計算準確率correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))sess = tf.Session()sess.run(tf.initialize_all_variables())# 下載minist的手寫數字的數據集mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)for i in range(20000):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = accuracy.eval(session=sess,feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})print("step %d,training accuracy %g"%(i,train_accuracy))train_step.run(session = sess,feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})print("test accuracy %g" % accuracy.eval(session=sess,feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))#test accuracy 0.9919
總結
以上是生活随笔為你收集整理的tensorflow实现CNN识别手写数字的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: linux查看网卡驱动,Linux操作系
- 下一篇: GANs简述 Generative Ad