TensorFlow保存和载入训练模型
生活随笔
收集整理的這篇文章主要介紹了
TensorFlow保存和载入训练模型
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
保存:使用saver.save()方法保存
載入:使用saver.restore()方法載入
下面是個(gè)完整例子:
保存:
import tensorflow as tfW = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w') b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')saver = tf.train.Saver() with tf.Session() as sess:sess.run(tf.global_variables_initializer())save_path = saver.save(sess, r"D:\test\wb") # 將W、b保存到指定位置載入:?
import tensorflow as tfW = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w') b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')saver = tf.train.Saver() with tf.Session() as sess:saver.restore(sess, r"D:\test\wb") # 從指定位置加載模型print(sess.run(W))print(sess.run(b)) """ 輸出: [[1. 1. 1.][2. 2. 2.]][[0. 1. 2.]] """就算W和b定義了不同于模型的值,但是仍會輸出載入模型的值,如:
import tensorflow as tfW = tf.Variable([[0,0,0],[0,0,0]],dtype = tf.float32,name='w') b = tf.Variable([[0,0,0]],dtype = tf.float32,name='b')saver = tf.train.Saver() with tf.Session() as sess:saver.restore(sess, r"D:\test\wb")print(sess.run(W))print(sess.run(b)) """ 輸出: [[1. 1. 1.][2. 2. 2.]][[0. 1. 2.]] """這種方法不方便的在于,在使用模型的時(shí)候,必須把模型的結(jié)構(gòu)重新定義一遍,然后載入對應(yīng)名字的變量的值。
總結(jié)
以上是生活随笔為你收集整理的TensorFlow保存和载入训练模型的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python matplotlib画图是
- 下一篇: tf.while_loop