生活随笔
收集整理的這篇文章主要介紹了
基于tensorflow框架的神经网络结构处理mnist数据集
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
一、構建計算圖
準備訓練數據定義前向計算過程 Inference定義loss(loss,accuracy等scalar用tensorboard展示)定義訓練方法變量初始化保存計算圖
二、創建會話
summary對象處理喂入數據,得到觀測的loss,accuracy等用測試數據測試模型
import tensorflow
as tf
import numpy
as np
import os
from tensorflow
.examples
.tutorials
.mnist
import input_data
os
.environ
['TF_CPP_MIN_LOG_LEVEl']='3'
tf
.reset_default_graph
()
mnist
= input_data
.read_data_sets
('D:\MyData\zengxf\.keras\datasets\MNIST_data',one_hot
=True)
xq
,yq
= mnist
.train
.next_batch
(2)h1
= 100
h2
= 10with tf
.name_scope
("Input"):X
= tf
.placeholder
("float",[None,784],name
='X')Y_true
= tf
.placeholder
("float",[None,10],name
='Y_true')
with tf
.name_scope
("Inference"):with tf
.name_scope
("hidden1"):W1
= tf
.Variable
(tf
.random_normal
([784, h1
])*0.1, name
='W1')b1
= tf
.Variable
(tf
.zeros
([h1
]), name
='b1')y_1
= tf
.nn
.sigmoid
(tf
.matmul
(X
, W1
)+b1
)with tf
.name_scope
("hidden2"):W2
= tf
.Variable
(tf
.random_normal
([h1
, h2
])*0.1, name
='W2')b2
= tf
.Variable
(tf
.zeros
([h2
]), name
='b2')y_2
= tf
.nn
.sigmoid
(tf
.matmul
(y_1
, W2
)+b2
)with tf
.name_scope
("Output"):W3
= tf
.Variable
(tf
.truncated_normal
([h2
, 10])*0.1, name
='W3')b3
= tf
.Variable
(tf
.zeros
([10]), name
='b3')y
= tf
.nn
.softmax
(tf
.matmul
(y_2
, W3
)+ b3
)
with tf
.name_scope
("Loss"):loss
= tf
.reduce_mean
(-tf
.reduce_sum
(tf
.multiply
(Y_true
,tf
.log
(y
))))loss_scalar
= tf
.summary
.scalar
('loss',loss
)accuracy
= tf
.reduce_mean
(tf
.cast
(tf
.equal
(tf
.argmax
(y
,1),tf
.argmax
(Y_true
,1)),tf
.float32
))accuracy_scalar
= tf
.summary
.scalar
('accuracy', accuracy
)with tf
.name_scope
("Trian"):optimizer
= tf
.train
.GradientDescentOptimizer
(learning_rate
=0.01)train_op
= optimizer
.minimize
(loss
)init
= tf
.global_variables_initializer
()
merge_summary_op
= tf
.summary
.merge_all
()
writer
= tf
.summary
.FileWriter
('logs', tf
.get_default_graph
())
sess
= tf
.Session
()
sess
.run
(init
)
for step
in range(5000):train_x
,train_y
= mnist
.train
.next_batch
(500)_
,summary_op
,train_loss
,acc
= sess
.run
([train_op
,merge_summary_op
,loss
,accuracy
],feed_dict
={X
:train_x
,Y_true
:train_y
})if step
%100==99:print('loss=',train_loss
)writer
.add_summary
(summary_op
,step
)
print(sess
.run
(accuracy
, feed_dict
={X
: mnist
.test
.images
, Y_true
: mnist
.test
.labels
}))
writer
.close
()
總結
以上是生活随笔為你收集整理的基于tensorflow框架的神经网络结构处理mnist数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。