Implementing an MNIST Autoencoder in TensorFlow (3)
The earlier autoencoder posts (1) and (2) improved the model from the angle of reducing the dimensionality of high-dimensional data. But we also want the learned features to resist interference: when the input data is perturbed, the generated features should stay largely unchanged, so that the autoencoder generalizes better. This is the idea behind the denoising autoencoder: corrupt the input, but train the network to reconstruct the original, clean input.
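The corruption step on its own can be sketched in plain NumPy. This is a minimal sketch: the random `clean` batch here is only a stand-in for real MNIST images, and 0.3 is the noise scale used in the training loop below.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-in for a batch of flattened 28x28 MNIST images (values in [0, 1]).
clean = rng.random((256, 784)).astype(np.float32)

# Corrupt with zero-mean Gaussian noise of standard deviation 0.3,
# the same corruption the training loop applies to each batch.
noisy = clean + 0.3 * rng.standard_normal((256, 784)).astype(np.float32)

# A denoising autoencoder is trained to map `noisy` back to `clean`.
# The per-pixel MSE introduced by the corruption alone is roughly
# 0.3**2 = 0.09; training should drive the reconstruction error well below it.
corruption_mse = float(np.mean((noisy - clean) ** 2))
print(corruption_mse)
```

Feeding `noisy` as the input while keeping `clean` as the target is exactly what the `feeds` dictionary in the training loop does.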
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/data/', one_hot=True)

train_x = mnist.train.images
train_y = mnist.train.labels
test_x = mnist.test.images
test_y = mnist.test.labels

n_hidden_1 = 256   # size of both hidden layers
n_input = 784      # 28 x 28 pixels per flattened MNIST image

x = tf.placeholder('float', [None, n_input])   # noisy input
y = tf.placeholder('float', [None, n_input])   # clean reconstruction target
dropout_keep_prob = tf.placeholder('float')

weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_1])),
    'out': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
biases = {
    'b1': tf.Variable(tf.zeros([n_hidden_1])),
    'b2': tf.Variable(tf.zeros([n_hidden_1])),
    'out': tf.Variable(tf.zeros([n_input]))
}

def denoise_auto_encoder(X, weights, biases, keep_prob):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(X, weights['h1']), biases['b1']))
    layer_1out = tf.nn.dropout(layer_1, keep_prob)
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1out, weights['h2']), biases['b2']))
    layer_2out = tf.nn.dropout(layer_2, keep_prob)
    return tf.nn.sigmoid(tf.matmul(layer_2out, weights['out']) + biases['out'])

reconstruction = denoise_auto_encoder(x, weights, biases, dropout_keep_prob)
cost = tf.reduce_mean(tf.pow(reconstruction - y, 2))   # pixel-wise MSE against the clean image
optm = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)

epochs = 20
batch_size = 256
disp_step = 2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('Start training')
    for epoch in range(epochs):
        num_batch = int(mnist.train.num_examples / batch_size)
        total_cost = 0
        for i in range(num_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Corrupt the input with Gaussian noise; the clean batch stays as the target.
            batch_xs_noisy = batch_xs + 0.3 * np.random.randn(batch_size, 784)
            feeds = {x: batch_xs_noisy, y: batch_xs, dropout_keep_prob: 1.0}
            sess.run(optm, feed_dict=feeds)
            total_cost += sess.run(cost, feed_dict=feeds)
        if epoch % disp_step == 0:
            print('Epoch %2d/%2d average cost: %.6f' % (epoch, epochs, total_cost / num_batch))
    print('Finished')
Summary
The above is the full content of "Implementing an MNIST Autoencoder in TensorFlow (3)" as collected and organized by 生活随笔; I hope it helps you solve the problems you encountered.