TensorFlow Basics: Saving and Loading Training Results
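Once training is finished, the trained model can be used directly for prediction.
Retraining before every prediction is not a normal workflow, though; some complex models take several days or even longer to train.
The normal workflow is to save the trained model and reload it whenever it is needed, either for prediction or to continue training.
We again use the simplest possible example, reusing the one from the earlier article "TensorFlow Basics: Prediction with the Simplest Linear Regression Model".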
====================================================
Saving the model
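In TensorFlow, a model is saved with a tf.train.Saver object, which has to be instantiated before saving:

    saver = tf.train.Saver()

Saving a model essentially means saving the state held by the session (the current values of its variables) under a given path. The directory in that path must already exist; otherwise an error is raised.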
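Since the target directory has to exist before saving, one simple guard is to create it up front. The following is a minimal sketch using only the standard library; ensure_save_dir is a hypothetical helper name, not part of TensorFlow, and its argument is assumed to be a directory like the SAVE_PATH used in this article.

```python
import os


def ensure_save_dir(save_path):
    """Create the checkpoint directory if it is missing and return it.

    `ensure_save_dir` is a hypothetical helper for illustration only.
    """
    # exist_ok=True makes the call a no-op when the directory is already there.
    os.makedirs(save_path, exist_ok=True)
    return save_path
```

Calling ensure_save_dir(SAVE_PATH) once before saving would avoid the missing-directory error.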
The save call itself takes the session and a path prefix:

    saver.save(sess, SAVE_PATH + 'model')

After saving, four files appear in the target directory:
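model.meta stores the model, but only the definition of the computation graph and its parameters; it can be thought of as an untrained model.
model.index and model.data-00000-of-00001 store the parameter values, i.e. the actual result of training.
checkpoint records the most recent saves. As the file name suggests, it is a checkpoint index that tracks how the other files relate to each other. It is a plain-text file that can simply be opened and read (this example saves only once; with multiple saves, the file records each of them).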
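To make the role of the checkpoint file more concrete, the sketch below parses its text with the standard library and picks out the most recent save. The sample content illustrates the format TensorFlow 1.x typically writes and is not copied from the run in this article; read_latest_prefix and read_all_prefixes are hypothetical helpers. In real code, tf.train.latest_checkpoint performs this lookup.

```python
import re

# Illustrative content of a `checkpoint` file after a single save (assumed sample).
SAMPLE_CHECKPOINT = '''model_checkpoint_path: "model"
all_model_checkpoint_paths: "model"
'''


def read_latest_prefix(checkpoint_text):
    """Return the prefix named by model_checkpoint_path, or None if absent."""
    match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"',
                      checkpoint_text, re.MULTILINE)
    return match.group(1) if match else None


def read_all_prefixes(checkpoint_text):
    """Return every prefix listed under all_model_checkpoint_paths, in file order."""
    return re.findall(r'^all_model_checkpoint_paths:\s*"([^"]*)"',
                      checkpoint_text, re.MULTILINE)


print(read_latest_prefix(SAMPLE_CHECKPOINT))
print(read_all_prefixes(SAMPLE_CHECKPOINT))
```

The prefix found this way ("model" here) is the common stem of the .index and .data files described above.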
Below is the log of the run (the value labeled _n is the bias B):

    epoch= 0 _loss= 6029.333 _w= [0.005] _n= [0.005]
    epoch= 5000 _loss= 10.897877 _w= [4.2031364] _n= [-1.905781]
    epoch= 10000 _loss= 112.455055 _w= [4.7837024] _n= [-11.81817]
    epoch= 15000 _loss= 6.2376847 _w= [5.1548934] _n= [-19.740992]
    epoch= 20000 _loss= 2.9357195 _w= [5.2787647] _n= [-22.662355]
    epoch= 25000 _loss= 0.022824269 _w= [5.3112087] _n= [-23.141117]
    epoch= 30000 _loss= 1.3711997 _w= [5.326612] _n= [-23.255548]
    epoch= 35000 _loss= 0.005477888 _w= [5.3088646] _n= [-23.289743]
    epoch= 40000 _loss= 2.8727396 _w= [5.315157] _n= [-23.191956]
    epoch= 45000 _loss= 0.009563584 _w= [5.300157] _n= [-23.18857]
    Training finished, starting prediction...
    x= 0.1610020536371326 y_pred= [-22.44688] y_actual= -22.401859054114084
    x= 7.379937860774309 y_pred= [16.030691] y_actual= 16.075068797927063
    x= 5.1744928042152685 y_pred= [4.2754745] y_actual= 4.320046646467379
    x= 10.26990231423617 y_pred= [31.434462] y_actual= 31.478579334878784
    x= 23.219346463697207 y_pred= [100.45616] y_actual= 100.49911665150611
    x= 7.101197776563807 y_pred= [14.544985] y_actual= 14.589384149085088
    x= 3.097841295090581 y_pred= [-6.7932644] y_actual= -6.7485058971672025
    x= 6.474682013005717 y_pred= [11.205599] y_actual= 11.250055129320469
    x= 13.811264369891983 y_pred= [50.310234] y_actual= 50.35403909152427
    x= 29.260954830177415 y_pred= [132.65846] y_actual= 132.70088924484563

====================================================
Loading the model
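Because the model and the parameter values were saved as two separate parts, loading also happens in two steps: first the model, then the parameter values (the training result).
As described above, the meta file holds the model, while the parameter values live in the index and data files that the checkpoint file points to. The model is loaded with tf.train.import_meta_graph, and the values are restored from the path returned by tf.train.latest_checkpoint: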
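    saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
    saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))

With that, the previously saved model has been reloaded. Before predicting or continuing training, however, we need Python references to the relevant tensors again.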
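They are not defined from scratch: they were already defined in the saved model, so we only need to look them up in the loaded graph.
To make that lookup possible, the model has to be modified slightly so that everything we want to retrieve after reloading has a name. For prediction we need X and OUT; for continued training we need X, Y, OUT and loss. So all of these are given names in the model: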
    X = tf.placeholder(tf.float32, name='X')
    Y = tf.placeholder(tf.float32, name='Y')

    W = tf.Variable(tf.zeros([1]), name='W')
    B = tf.Variable(tf.zeros([1]), name='B')
    OUT = tf.add(tf.multiply(X, W), B, name='OUT')

    loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
    optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

Before retrieving anything, we need the computation graph object (the concept of a computation graph does not have to be understood in detail yet):
    graph = tf.get_default_graph()

get_all_collection_keys then shows what the model contains:
    print(graph.get_all_collection_keys())

There are three entries: train_op (the optimizer's training op), trainable_variables (the trainable variables) and variables (all variables):
    ['train_op', 'trainable_variables', 'variables']

get_collection prints the objects behind each key:
    print(graph.get_collection('train_op'))
    print(graph.get_collection('trainable_variables'))
    print(graph.get_collection('variables'))

However, the tensors we are looking for are not all here:
    [<tf.Operation 'Adam' type=NoOp>]
    [<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>]
    [<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>, <tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>, <tf.Variable 'W/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'W/Adam_1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam_1:0' shape=(1,) dtype=float32_ref>]

Next, get_operations lists every operation in the graph:
    print(graph.get_operations())

The output below contains the ones we need: X, Y, OUT and so on.
[<tf.Operation 'X' type=Placeholder>, <tf.Operation 'Y' type=Placeholder>, <tf.Operation 'zeros' type=Const>, <tf.Operation 'W' type=VariableV2>, <tf.Operation 'W/Assign' type=Assign>, <tf.Operation 'W/read' type=Identity>, <tf.Operation 'zeros_1' type=Const>, <tf.Operation 'B' type=VariableV2>, <tf.Operation 'B/Assign' type=Assign>, <tf.Operation 'B/read' type=Identity>, <tf.Operation 'Mul' type=Mul>, <tf.Operation 'OUT' type=Add>, <tf.Operation 'sub' type=Sub>, <tf.Operation 'Square' type=Square>, <tf.Operation 'Rank' type=Rank>, <tf.Operation 'range/start' type=Const>, <tf.Operation 'range/delta' type=Const>, <tf.Operation 'range' type=Range>, <tf.Operation 'loss' type=Mean>, <tf.Operation 'gradients/Shape' type=Const>, <tf.Operation 'gradients/grad_ys_0' type=Const>, <tf.Operation 'gradients/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/Shape' type=Shape>, <tf.Operation 'gradients/loss_grad/Size' type=Size>, <tf.Operation 'gradients/loss_grad/add' type=Add>, <tf.Operation 'gradients/loss_grad/mod' type=FloorMod>, <tf.Operation 'gradients/loss_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/loss_grad/range/start' type=Const>, <tf.Operation 'gradients/loss_grad/range/delta' type=Const>, <tf.Operation 'gradients/loss_grad/range' type=Range>, <tf.Operation 'gradients/loss_grad/Fill/value' type=Const>, <tf.Operation 'gradients/loss_grad/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/DynamicStitch' type=DynamicStitch>, <tf.Operation 'gradients/loss_grad/Maximum/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/loss_grad/Tile' type=Tile>, <tf.Operation 'gradients/loss_grad/Shape_2' type=Shape>, <tf.Operation 'gradients/loss_grad/Shape_3' type=Const>, <tf.Operation 'gradients/loss_grad/Const' type=Const>, <tf.Operation 'gradients/loss_grad/Prod' type=Prod>, <tf.Operation 
'gradients/loss_grad/Const_1' type=Const>, <tf.Operation 'gradients/loss_grad/Prod_1' type=Prod>, <tf.Operation 'gradients/loss_grad/Maximum_1/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum_1' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv_1' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Cast' type=Cast>, <tf.Operation 'gradients/loss_grad/truediv' type=RealDiv>, <tf.Operation 'gradients/Square_grad/Const' type=Const>, <tf.Operation 'gradients/Square_grad/Mul' type=Mul>, <tf.Operation 'gradients/Square_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/sub_grad/Shape' type=Shape>, <tf.Operation 'gradients/sub_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/sub_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/sub_grad/Sum' type=Sum>, <tf.Operation 'gradients/sub_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/sub_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/sub_grad/Neg' type=Neg>, <tf.Operation 'gradients/sub_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/sub_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/OUT_grad/Shape' type=Shape>, <tf.Operation 'gradients/OUT_grad/Shape_1' type=Const>, <tf.Operation 'gradients/OUT_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/OUT_grad/Sum' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/OUT_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/OUT_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/Mul_grad/Shape' type=Shape>, <tf.Operation 'gradients/Mul_grad/Shape_1' type=Const>, 
<tf.Operation 'gradients/Mul_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/Mul_grad/Mul' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/Mul_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/Mul_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'beta1_power/initial_value' type=Const>, <tf.Operation 'beta1_power' type=VariableV2>, <tf.Operation 'beta1_power/Assign' type=Assign>, <tf.Operation 'beta1_power/read' type=Identity>, <tf.Operation 'beta2_power/initial_value' type=Const>, <tf.Operation 'beta2_power' type=VariableV2>, <tf.Operation 'beta2_power/Assign' type=Assign>, <tf.Operation 'beta2_power/read' type=Identity>, <tf.Operation 'W/Adam/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam' type=VariableV2>, <tf.Operation 'W/Adam/Assign' type=Assign>, <tf.Operation 'W/Adam/read' type=Identity>, <tf.Operation 'W/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam_1' type=VariableV2>, <tf.Operation 'W/Adam_1/Assign' type=Assign>, <tf.Operation 'W/Adam_1/read' type=Identity>, <tf.Operation 'B/Adam/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam' type=VariableV2>, <tf.Operation 'B/Adam/Assign' type=Assign>, <tf.Operation 'B/Adam/read' type=Identity>, <tf.Operation 'B/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam_1' type=VariableV2>, <tf.Operation 'B/Adam_1/Assign' type=Assign>, <tf.Operation 'B/Adam_1/read' type=Identity>, <tf.Operation 'Adam/learning_rate' type=Const>, <tf.Operation 'Adam/beta1' type=Const>, <tf.Operation 'Adam/beta2' type=Const>, <tf.Operation 'Adam/epsilon' type=Const>, <tf.Operation 'Adam/update_W/ApplyAdam' type=ApplyAdam>, 
<tf.Operation 'Adam/update_B/ApplyAdam' type=ApplyAdam>, <tf.Operation 'Adam/mul' type=Mul>, <tf.Operation 'Adam/Assign' type=Assign>, <tf.Operation 'Adam/mul_1' type=Mul>, <tf.Operation 'Adam/Assign_1' type=Assign>, <tf.Operation 'Adam' type=NoOp>, <tf.Operation 'init' type=NoOp>, <tf.Operation 'save/filename/input' type=Const>, <tf.Operation 'save/filename' type=PlaceholderWithDefault>, <tf.Operation 'save/Const' type=PlaceholderWithDefault>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/Assign_3' type=Assign>, <tf.Operation 'save/Assign_4' type=Assign>, <tf.Operation 'save/Assign_5' type=Assign>, <tf.Operation 'save/Assign_6' type=Assign>, <tf.Operation 'save/Assign_7' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]

Restoring the parameters:
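Note that ":0" has to be appended to each name. A tensor name has the form <operation name>:<output index>, so ":0" selects output 0 of that operation (this involves another concept that will be covered in detail later).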
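The naming convention can be illustrated without TensorFlow. The helper below, split_tensor_name, is a hypothetical function written only for this explanation; it splits a name such as 'OUT:0' into the operation name and the output index.

```python
def split_tensor_name(tensor_name):
    """Split 'OUT:0' into ('OUT', 0): the operation name and its output index.

    Hypothetical helper for illustration; not a TensorFlow API.
    """
    # rpartition splits on the last ':' so scoped names such as
    # 'gradients/loss_grad/Shape:0' keep their slashes intact.
    op_name, sep, index = tensor_name.rpartition(':')
    if not sep or not index.isdigit():
        raise ValueError("expected '<op_name>:<output_index>', got %r" % tensor_name)
    return op_name, int(index)


for name in ('X:0', 'OUT:0', 'loss:0'):
    print(name, '->', split_tensor_name(name))
```

The operation names on the left-hand side are exactly the ones that appear in the get_operations output.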
The tensors are then fetched by name:

    X = graph.get_tensor_by_name('X:0')
    Y = graph.get_tensor_by_name('Y:0')
    W = graph.get_tensor_by_name('W:0')
    B = graph.get_tensor_by_name('B:0')
    OUT = graph.get_tensor_by_name('OUT:0')
    loss = graph.get_tensor_by_name('loss:0')

Restoring the optimizer:
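    optimizer = graph.get_collection('train_op')

get_collection returns a list, here holding the single Adam op, and sess.run accepts that list as a fetch. The prediction and training logic from the earlier code is then copied over and run again.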
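For prediction alone, only X and OUT have to be recovered. The following is a minimal sketch of such a loader, not the author's code. predict_only is a hypothetical function name, and the import assumes a TensorFlow installation that ships the tensorflow.compat.v1 module; on the TensorFlow 1.13 used in this article, a plain import tensorflow as tf without the disable_eager_execution call should serve the same purpose.

```python
import os

# Assumption: a TensorFlow build that provides the compat.v1 API.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def predict_only(save_dir, xs):
    """Restore the saved graph and weights, then evaluate OUT for each x.

    Hypothetical helper; it relies on the tensors being named 'X' and 'OUT'
    as in the model definition above.
    """
    graph = tf.Graph()  # a fresh graph, so repeated calls do not collide
    with graph.as_default():
        with tf.Session(graph=graph) as sess:
            saver = tf.train.import_meta_graph(os.path.join(save_dir, 'model.meta'))
            saver.restore(sess, tf.train.latest_checkpoint(save_dir))
            X = graph.get_tensor_by_name('X:0')
            OUT = graph.get_tensor_by_name('OUT:0')
            # OUT has shape (1,) for a scalar x, hence the [0].
            return [float(sess.run(OUT, feed_dict={X: x})[0]) for x in xs]
```

Loading into a separate tf.Graph is a design choice of this sketch: it keeps the restored operations out of whatever default graph the caller has already built.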
The run produces the following log:

    Instructions for updating: Use standard file APIs to check for files with this prefix.
    INFO:tensorflow:Restoring parameters from D:/test/tf1/xw_b/model
    Model reloaded, starting prediction...
    x= 26.764991404677083 y_pred= [[119.67893]] y_actual= 119.39740418692885
    x= 25.85141169466281 y_pred= [[114.797356]] y_actual= 114.52802433255279
    x= 17.046457082367727 y_pred= [[67.749466]] y_actual= 67.59761624901998
    x= 5.918111849660451 y_pred= [[8.286896]] y_actual= 8.283536158690204
    x= 7.409698341670607 y_pred= [[16.256956]] y_actual= 16.233692161104333
    x= 15.469762867798304 y_pred= [[59.324646]] y_actual= 59.19383608536495
    x= 11.519144276233455 y_pred= [[38.215134]] y_actual= 38.13703899232431
    x= 27.85137286496477 y_pred= [[125.48383]] y_actual= 125.18781737026221
    x= 26.50150532742774 y_pred= [[118.271034]] y_actual= 117.99302339518984
    x= 15.664275922154658 y_pred= [[60.364]] y_actual= 60.23059066508432
    Resuming training
    epoch= 0 _loss= 16.00476 _w= [5.3422985] _n= [-23.3365]
    epoch= 5000 _loss= 19.420956 _w= [5.3203373] _n= [-23.186474]
    epoch= 10000 _loss= 0.30325127 _w= [5.3471537] _n= [-23.290209]
    epoch= 15000 _loss= 3.018042 _w= [5.32293] _n= [-23.245607]
    epoch= 20000 _loss= 12.473472 _w= [5.309146] _n= [-23.24814]
    epoch= 25000 _loss= 17.09799 _w= [5.3170156] _n= [-23.342768]
    epoch= 30000 _loss= 18.25596 _w= [5.3193855] _n= [-23.225794]
    epoch= 35000 _loss= 0.32235628 _w= [5.339825] _n= [-23.196495]
    epoch= 40000 _loss= 2.6598516 _w= [5.304051] _n= [-23.248428]
    epoch= 45000 _loss= 6.564373 _w= [5.328891] _n= [-23.212101]
    Resumed training finished, starting prediction...
    x= 24.14983880390778 y_pred= [[105.329315]] y_actual= 105.45864082482846
    x= 8.654129156050717 y_pred= [[22.795414]] y_actual= 22.86650840175032
    x= 17.410606725772045 y_pred= [[69.434525]] y_actual= 69.53853384836499
    x= 17.55599000188004 y_pred= [[70.20888]] y_actual= 70.31342671002061
    x= 24.43148021367975 y_pred= [[106.82939]] y_actual= 106.95978953891309
    x= 20.286380740475614 y_pred= [[84.751595]] y_actual= 84.86640934673503
    x= 2.8131286438423353 y_pred= [[-8.3151655]] y_actual= -8.266024328320354
    x= 11.781139561484927 y_pred= [[39.450626]] y_actual= 39.53347386271466
    x= 4.611147529065006 y_pred= [[1.2615166]] y_actual= 1.3174163299164796
    x= 6.625783852577516 y_pred= [[11.991955]] y_actual= 12.055427934238164

Predictions made directly with the restored model match the actual values very closely, and continuing training from the restored state works as well.
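How close the match is can be quantified from the log itself. The sketch below copies the first three (predicted, actual) pairs of each prediction phase from the output above and computes their mean absolute error; the helper name mean_abs_error and the choice of three pairs per phase are arbitrary.

```python
# (predicted, actual) pairs copied from the log: right after reloading ...
after_reload = [
    (119.67893, 119.39740418692885),
    (114.797356, 114.52802433255279),
    (67.749466, 67.59761624901998),
]
# ... and after the additional 50000 training steps.
after_resume = [
    (105.329315, 105.45864082482846),
    (22.795414, 22.86650840175032),
    (69.434525, 69.53853384836499),
]


def mean_abs_error(pairs):
    """Average of |predicted - actual| over the given pairs."""
    return sum(abs(pred - actual) for pred, actual in pairs) / len(pairs)


print("after reload:", mean_abs_error(after_reload))
print("after resume:", mean_abs_error(after_resume))
```

For these samples both averages stay below 0.3, against targets that range from roughly -8 to 125 in the log.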
====================================================
The complete code follows; it runs successfully under Python 3.6.8 and TensorFlow 1.13.
https://github.com/yukiti2007/sample/blob/master/python/tensorflow/wx_b_save.py
    import random

    import tensorflow as tf

    SAVE_PATH = "D:/test/tf1/xw_b/"


    def create_data(for_train=False):
        # Ground truth: y = 5.33 * x - 23.26, with uniform noise added for training.
        w = 5.33
        b = -23.26
        x = random.random() * 30
        y = w * x + b
        if for_train:
            noise = (random.random() - 0.5) * 10
            y += noise
        return x, y


    def train():
        # Everything that must be found again after reloading gets a name.
        X = tf.placeholder(tf.float32, name='X')
        Y = tf.placeholder(tf.float32, name='Y')
        W = tf.Variable(tf.zeros([1]), name='W')
        B = tf.Variable(tf.zeros([1]), name='B')
        OUT = tf.add(tf.multiply(X, W), B, name='OUT')
        loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
        optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for epoch in range(50000):
                x_data, y_data = create_data(True)
                _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
                if 0 == epoch % 5000:
                    # The "_n=" label is kept to match the logs; the value is the bias B.
                    print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

            print("Training finished, starting prediction...")
            for step in range(10):
                x_data, y_data = create_data(False)
                prediction_value = sess.run(OUT, feed_dict={X: x_data})
                print("x=", x_data, "y_pred=", prediction_value, "y_actual=", y_data)

            # Save graph definition and variable values under SAVE_PATH.
            saver = tf.train.Saver()
            saver.save(sess, SAVE_PATH + 'model')


    def predict():
        sess = tf.Session()
        # Step 1: rebuild the graph from the meta file; step 2: restore the values.
        saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
        saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))
        graph = tf.get_default_graph()

        # Look up the named tensors (":0" selects output 0 of each operation).
        X = graph.get_tensor_by_name('X:0')
        Y = graph.get_tensor_by_name('Y:0')
        W = graph.get_tensor_by_name('W:0')
        B = graph.get_tensor_by_name('B:0')
        OUT = graph.get_tensor_by_name('OUT:0')
        loss = graph.get_tensor_by_name('loss:0')
        # A list containing the Adam training op.
        optimizer = graph.get_collection('train_op')

        # print(graph.get_all_collection_keys())
        # print(graph.get_collection('train_op'))
        # print(graph.get_collection('trainable_variables'))
        # print(graph.get_collection('variables'))
        # print(graph.get_operations())

        print("Model reloaded, starting prediction...")
        for step in range(10):
            x_data, y_data = create_data(False)
            prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
            print("x=", x_data, "y_pred=", prediction_value, "y_actual=", y_data)

        print("Resuming training")
        for epoch in range(50000):
            x_data, y_data = create_data(True)
            _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
            if 0 == epoch % 5000:
                print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

        print("Resumed training finished, starting prediction...")
        for step in range(10):
            x_data, y_data = create_data(False)
            prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
            print("x=", x_data, "y_pred=", prediction_value, "y_actual=", y_data)


    if __name__ == "__main__":
        train()
        predict()

Reposted from: https://my.oschina.net/u/4105485/blog/3034104