TensorFlow: This Is How You Should Learn It -- Part 2
1. The model mechanism
tensor -- represents data; it can be understood as a multidimensional array
variable -- represents a variable; in a model these are the defined parameters, whose values are obtained through continual training
placeholder -- represents a placeholder; it can be understood as defining a function's parameters, to be filled with concrete values later
2. The two ways of using a session (a further way to start one is sess = tf.InteractiveSession())
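A minimal sketch of the session styles (the article targets TF 1.x; on TF 2.x the same API is reached through the compat.v1 module, as assumed here):

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style API; on TF 2.x this is the compat shim
tf.disable_v2_behavior()

a = tf.constant(3)
b = tf.constant(4)

# Way 1: create and close the session by hand
sess = tf.Session()
r1 = sess.run(a + b)
sess.close()

# Way 2: a with-block closes the session automatically (used in the examples below)
with tf.Session() as sess:
    r2 = sess.run(a + b)

# The InteractiveSession variant: it installs itself as the default session,
# so Tensor.eval() works without naming the session explicitly
sess = tf.InteractiveSession()
r3 = (a + b).eval()
sess.close()

print(r1, r2, r3)
```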
3. The injection mechanism (feeding concrete values into placeholders via feed_dict)
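The injection mechanism means supplying values for the placeholders through feed_dict at the moment sess.run is called; a minimal sketch:

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style API (tf.compat.v1 on TF 2.x)
tf.disable_v2_behavior()

a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
add = tf.add(a, b)
mul = tf.multiply(a, b)

with tf.Session() as sess:
    # feed_dict injects the actual values into the placeholders at run time
    add_val = sess.run(add, feed_dict={a: 3, b: 4})
    # several fetches can also be evaluated in a single call
    add_val, mul_val = sess.run([add, mul], feed_dict={a: 3, b: 4})
print(add_val, mul_val)
```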
4. Running computation on a specified GPU
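A sketch of pinning ops to a device with tf.device (assuming at most one GPU; with allow_soft_placement=True TensorFlow falls back to the CPU when the requested device is absent, so this also runs on a CPU-only machine):

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style API (tf.compat.v1 on TF 2.x)
tf.disable_v2_behavior()

# Pin these ops to GPU 0
with tf.device('/gpu:0'):
    a = tf.constant([1.0, 2.0])
    b = tf.constant([3.0, 4.0])
    c = a + b

# allow_soft_placement lets TensorFlow move the ops to the CPU if no GPU exists
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    result = sess.run(c)
print(result)
```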
5. Saving and loading a model
The complete example code below can be run directly:
import tensorflow as tf
import numpy as np

plotdata = {"batchsize": [], "loss": []}

# training data: y is roughly 2x plus noise
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3

tf.reset_default_graph()

# placeholders for the inputs
X = tf.placeholder("float")
Y = tf.placeholder("float")

# model parameters, learned during training
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

# forward model
z = tf.multiply(X, W) + b

# loss and optimizer
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2

saver = tf.train.Saver()
savedir = './'

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})
        if epoch % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", epoch + 1, "cost=", loss, "W=", sess.run(W), "b=", sess.run(b))
            if not (loss == "NA"):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)
    print("Finished!")
    # save the trained model
    saver.save(sess, savedir + 'linemodel.cpkt')
    print("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))

# load the model in a fresh session and use it for prediction
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2, savedir + 'linemodel.cpkt')
    print('x=0.1, z=', sess2.run(z, feed_dict={X: 0.1}))
6. Checkpoints: training is sometimes interrupted, so checkpoints can be saved along the way
The Saver argument max_to_keep=1 means at most one checkpoint file is kept
When loading, specify the iteration with load_epoch
The complete code:
import tensorflow as tf
import numpy as np

plotdata = {"batchsize": [], "loss": []}

train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3

tf.reset_default_graph()

X = tf.placeholder("float")
Y = tf.placeholder("float")

W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

z = tf.multiply(X, W) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2

# keep at most one checkpoint file
saver = tf.train.Saver(max_to_keep=1)
savedir = './'

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})
        if epoch % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", epoch + 1, "cost=", loss, "W=", sess.run(W), "b=", sess.run(b))
            if not (loss == "NA"):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)
            # global_step appends the epoch number to the checkpoint filename
            saver.save(sess, savedir + 'linemodel.cpkt', global_step=epoch)
    print("Finished!")
    print("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))

with tf.Session() as sess2:
    # the last checkpoint saved above is epoch 18 (saves happen every display_step epochs)
    load_epoch = 18
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2, savedir + 'linemodel.cpkt-' + str(load_epoch))
    print('x=0.1, z=', sess2.run(z, feed_dict={X: 0.1}))
Commonly used model functions:
saver = tf.train.Saver()  # create a Saver
saver.save(sess, save_path)  # save the model
saver.restore(sess, save_path)  # restore the model
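When checkpoints are saved with global_step, as in section 6, tf.train.latest_checkpoint can locate the newest file instead of hard-coding the step suffix; a small sketch (the variable name and the temporary directory are illustrative, not from the original article):

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # TF 1.x-style API (tf.compat.v1 on TF 2.x)
tf.disable_v2_behavior()

tf.reset_default_graph()
v = tf.Variable(42.0, name='v')
saver = tf.train.Saver()
ckpt_dir = tempfile.mkdtemp()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # save with a step suffix, as in the checkpoint example above
    saver.save(sess, os.path.join(ckpt_dir, 'linemodel.cpkt'), global_step=18)

# find the newest checkpoint file in the directory
path = tf.train.latest_checkpoint(ckpt_dir)
with tf.Session() as sess:
    saver.restore(sess, path)
    restored = sess.run(v)
print(path, restored)
```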
7. Visualization with TensorBoard
Add the summary operations (tf.summary...) to the model code; they are annotated with comments below. If they are unclear, treat them as a template: just place these few lines at the corresponding points in your own code.
The code:
import tensorflow as tf
import numpy as np

plotdata = {"batchsize": [], "loss": []}

train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3

tf.reset_default_graph()

X = tf.placeholder("float")
Y = tf.placeholder("float")

W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")

z = tf.multiply(X, W) + b
tf.summary.histogram('z', z)  # record the distribution of the predictions

cost = tf.reduce_mean(tf.square(Y - z))
tf.summary.scalar('loss_function', cost)  # record the scalar loss

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2

with tf.Session() as sess:
    sess.run(init)
    merged_summary_op = tf.summary.merge_all()  # merge all summaries into a single op
    # create the writer that saves the summaries (and the graph) to disk
    summary_writer = tf.summary.FileWriter('log/summaries', sess.graph)
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})
            # evaluate the merged summaries and append them to the log
            summary_str = sess.run(merged_summary_op, feed_dict={X: x, Y: y})
            summary_writer.add_summary(summary_str, epoch)
        if epoch % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", epoch + 1, "cost=", loss, "W=", sess.run(W), "b=", sess.run(b))
            if not (loss == "NA"):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)
    print("Finished!")
    print("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b))
Afterwards, view the results in TensorBoard: change into the parent directory of the summary logs and run the tensorboard command (the original post shows this step as a screenshot), e.g. tensorboard --logdir=log/summaries.
The output shows port 6006; open http://127.0.0.1:6006 in a browser to see the TensorBoard dashboard.
On Windows the steps are the same: go to the log directory, run the same tensorboard command, then open the browser to see the dashboard.
總結(jié)
以上是生活随笔為你收集整理的tensorflow 就该这么学--2的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網(wǎng)站內(nèi)容還不錯,歡迎將生活随笔推薦給好友。