5.2 TensorFlow:模型的加载,存储,实例
背景
之前已經寫過TensorFlow圖與模型的加載與存儲了,寫的很詳細,但是或聞有人沒看懂,所以在附上一個關于模型加載與存儲的例子,CODE是我偶然看到了,就記下來了.其中模型很巧妙,比之前numpy寫一大堆簡單多了,這樣有利于把主要注意力放在模型的加載與存儲上.
解析
創建保存文件的類:saver = tf.train.Saver()
saver = tf.train.Saver() ,即為常見保存模型,圖,數據的類,其內部結構在源碼中有詳細的解釋,這個之前的文章已經說過了,這次只講,我們如何我們具體要用的方法
saver.save() 保存
源碼結構
def save(self,sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix="meta",write_meta_graph=True,write_state=True):# 實際運用 : # saver = tf.train.Saver() # saver.save(sess, checkpoint_dir + 'model55.ckpt', global_step=i+1) # 注意,實際保存時 model55.ckpt 會被保存為多個文件常用的參數:
1. sess : 要保存的session
2. save_path :保存路徑,注意想要保存在代碼所在目錄下,前面不要加’/’不然會變成根目錄
3. global_step :多次迭代時,使用該參數,按照步驟保存
4. 保存文件如下,后面的-50,100,是按照步驟(global_step)保存的
調用
源碼結構
def restore(self, sess, save_path):# sess 即為 當前session # save_path : 與之前保存時的使用的名字一直 # 如果調取上一個例子存儲的模型:此時 save_path = checkpoint_dir + 'model55.ckpt' # 代碼實例 :saver.restore(sess, ckpt.model_checkpoint_path)ckpt文件
之前已經在原來的文章中寫過,這里有必要再發一次
TensorFlow模型會保存在后綴為.ckpt的文件中。保存后在save這個文件夾中會出現3個文件,因為TensorFlow會將計算圖的結構和圖上參數取值分開保存。
checkpoint文件保存了一個目錄下所有的模型文件列表,這個文件是tf.train.Saver類自動生成且自動維護的。在
checkpoint文件中維護了由一個tf.train.Saver類持久化的所有TensorFlow模型文件的文件名。當某個保存的TensorFlow模型文件被刪除時,這個模型所對應的文件名也會從checkpoint文件中刪除。checkpoint中內容的格式為CheckpointState
Protocol Buffer.
model.ckpt.meta文件保存了TensorFlow計算圖的結構,可以理解為神經網絡的網絡結構
TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的信息以及運行計算圖中節點所需要的元數據。TensorFlow中元圖是由MetaGraphDef
Protocol Buffer定義的。MetaGraphDef
中的內容構成了TensorFlow持久化時的第一個文件。保存MetaGraphDef
信息的文件默認以.meta為后綴名,文件model.ckpt.meta中存儲的就是元圖數據。
model.ckpt文件保存了TensorFlow程序中每一個變量的取值,這個文件是通過SSTable格式存儲的,可以大致理解為就是一個(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在這個文件中存儲的變量列表。列表剩下的每一行保存了一個變量的片段,變量片段的信息是通過SavedSlice
Protocol
Buffer定義的。SavedSlice類型中保存了變量的名稱、當前片段的信息以及變量取值。TensorFlow提供了tf.train.NewCheckpointReader類來查看model.ckpt文件中保存的變量信息。如何使用tf.train.NewCheckpointReader類這里不做說明,請自查。
CODE AND RUN
import tensorflow as tf import numpy as np import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'x = tf.placeholder(tf.float32, shape=[None, 1]) # 擬合 y y = 4 * x + 4w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + bloss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss)isTrain = False train_steps = 100 checkpoint_steps = 50 checkpoint_dir = 'save/'saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))with tf.Session() as sess:sess.run(tf.global_variables_initializer())if isTrain:for i in range(train_steps):sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0:saver.save(sess, checkpoint_dir + 'model55.ckpt', global_step=i+1)print(sess.run(w))print(sess.run(b))'''運行結果[ 3.87540483][ 4.07181311]最后訓練好的模型跑出來的數據[ 3.994277][ 4.00329876]'''else:ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)else:passprint(sess.run(w))print(sess.run(b))'''[ 3.994277][ 4.00329876]'''最后
更詳細的內容,請點擊這里
總結
以上是生活随笔為你收集整理的5.2 TensorFlow:模型的加载,存储,实例的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 卷积神经网络(cnn)的体系结构
- 下一篇: 安装Ubuntu16.04并安装sogo