5.1 TensorFlow: Loading and Saving Graphs and Models
Contents

- Preface
- Quick start
- Saving and loading: a simple example
- The saved files
- tf.train.Saver and the saved files explained
- Core definitions
- The saved files explained
- Saving graphs and models: going further
- Saving by iteration count
- Saving by time
- A more detailed explanation
Preface

I am learning TensorFlow on my own; the book I am currently reading is 《TensorFlow技術解析與實戰(zhàn)》. Frankly, the early chapters are a bit rough (I can't speak for the later ones yet): the section on loading graphs and models is unclear, and the book's code doesn't even run, which is, ahem, frustrating. So I set off on a journey of reading the documentation and trawling blogs to fill in the gaps. What follows is a summary of what I learned.
Quick start
Saving and loading: a simple example
```python
# Normally we build the model and then run it in a session. The difference
# here is that we save the model right after building it, then load the
# saved model inside a session and run it.
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Declare two variables
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
init_op = tf.global_variables_initializer()  # initialize all variables

# saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
# Declare a tf.train.Saver, which will save the model
saver = tf.train.Saver()

# Make sure the save directory exists
os.makedirs('save', exist_ok=True)

# Save only the graph
if not os.path.exists('save/model.meta'):
    saver.export_meta_graph('save/model.meta')
print()

with tf.Session() as sess:
    sess.run(init_op)
    print('v1:', sess.run(v1))  # print v1 and v2 so we can compare after reloading
    print('v2:', sess.run(v2))
    saver_path = saver.save(sess, 'save/model.ckpt')  # save the model to save/model.ckpt
    print('Model saved in file:', saver_path)
    print()

with tf.Session() as sess:
    # Read the model persisted to disk back from its save path, so we can
    # directly reuse a model that was fully (or partially) trained before
    saver.restore(sess, 'save/model.ckpt')
    print('v1:', sess.run(v1))  # print v1 and v2 to compare with the values above
    print('v2:', sess.run(v2))
    print('Model Restored')
    print()

# Load only the graph
saver = tf.train.import_meta_graph('save/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt')
    # Fetch tensors by name; you can also run new tensors directly
    print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
    print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))
```

Output:
```
v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model saved in file: save/model.ckpt

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model Restored

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
```

The values printed by running the freshly built model, by loading the saved model, and by loading the saved graph and fetching the tensors by name are all identical.
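As a side note, the restore path does not have to be hard-coded. Here is a minimal sketch, assuming the same `save/` directory layout as above, that looks up the newest checkpoint with `tf.train.latest_checkpoint`. The `tf.compat.v1` shim is only needed when running under TensorFlow 2.x; on 1.x the module-level API is used directly.

```python
# Sketch: restore the most recent checkpoint without hard-coding its path.
# The compat shim below is only needed on TensorFlow 2.x.
import os
import tensorflow as tf

if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
    tf = tf.compat.v1
    tf.disable_eager_execution()

v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
saver = tf.train.Saver()
os.makedirs('save', exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved = sess.run(v1)
    saver.save(sess, 'save/model.ckpt')  # also writes save/checkpoint

with tf.Session() as sess:
    # Look up the newest checkpoint recorded in save/checkpoint
    latest = tf.train.latest_checkpoint('save')
    saver.restore(sess, latest)
    restored = sess.run(v1)

print(latest)                      # e.g. 'save/model.ckpt'
print((saved == restored).all())   # True: values survive the round trip
```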
The saved files
tf.train.Saver and the saved files explained
Core definitions
Main class: the tf.train.Saver class is responsible for saving and restoring the neural network.

Saving automatically produces three kinds of files: checkpoint, the list of model files; model.ckpt.meta, the structure of the computation graph; and model.ckpt, the value of every variable. The first two are generated automatically.

Loading a persisted graph: use tf.train.import_meta_graph("save/model.ckpt.meta") to load the persisted graph.
The saved files explained
In the code above, the saver.save call writes the TensorFlow model out to disk. The path given is "save/model.ckpt", so the model ends up in the save folder inside the directory the program runs from.

TensorFlow models are saved in files with the .ckpt suffix. After saving, three files appear in the save folder, because TensorFlow stores the structure of the computation graph and the values of the parameters on that graph separately.
The checkpoint file holds the list of all model files in a directory; it is generated and maintained automatically by the tf.train.Saver class. It records the filenames of every TensorFlow model file persisted by a given tf.train.Saver. When a saved model file is deleted, its filename is removed from the checkpoint file as well. The contents of checkpoint follow the CheckpointState Protocol Buffer format.
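The CheckpointState proto can be read back programmatically with `tf.train.get_checkpoint_state`. Below is a minimal sketch; the directory name is just illustrative, and the `tf.compat.v1` shim is only needed when running under TensorFlow 2.x.

```python
# Sketch: write one checkpoint, then parse the 'checkpoint' file that
# tf.train.Saver maintains in the same directory.
import os
import tensorflow as tf

if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
    tf = tf.compat.v1
    tf.disable_eager_execution()

v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
saver = tf.train.Saver()
os.makedirs('save', exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'save/model.ckpt')

# Parse the CheckpointState protocol buffer out of save/checkpoint
state = tf.train.get_checkpoint_state('save')
print(state.model_checkpoint_path)             # the most recent checkpoint
print(list(state.all_model_checkpoint_paths))  # every checkpoint still kept
```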
The model.ckpt.meta file stores the structure of the TensorFlow computation graph, which you can think of as the architecture of the neural network. TensorFlow uses a meta graph (MetaGraph) to record the information about the nodes in the computation graph and the metadata needed to run them. The meta graph is defined by the MetaGraphDef Protocol Buffer, and its contents make up the first file TensorFlow writes during persistence. Files holding MetaGraphDef data use the .meta suffix by default; model.ckpt.meta stores exactly this meta graph data.
The model.ckpt file stores the value of every variable in the TensorFlow program. It uses the SSTable format, which you can roughly picture as a list of (key, value) pairs. The first entry in the file describes its metadata, such as the list of variables stored in it. Each remaining entry stores one variable slice; slices are defined by the SavedSlice Protocol Buffer, which records the variable's name, information about the current slice, and the variable's value. TensorFlow provides the tf.train.NewCheckpointReader class to inspect the variables saved in a model.ckpt file; its full usage is not covered in detail here.
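For the curious, a minimal sketch of `tf.train.NewCheckpointReader` might look like the following. The paths are illustrative, and the `tf.compat.v1` shim is only needed when running under TensorFlow 2.x.

```python
# Sketch: save two variables, then inspect the checkpoint's contents
# without building a graph or opening a session to restore them.
import os
import tensorflow as tf

if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
    tf = tf.compat.v1
    tf.disable_eager_execution()

v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
saver = tf.train.Saver()
os.makedirs('save', exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'save/model.ckpt')

reader = tf.train.NewCheckpointReader('save/model.ckpt')
shapes = reader.get_variable_to_shape_map()  # {'v1': [1, 2], 'v2': [2, 3]}
print(shapes)
print(reader.get_tensor('v1'))  # the saved value of v1, as a numpy array
```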
Saving graphs and models: going further
Saving by iteration count
```python
# Save at iteration 1000
saver.save(sess, 'my_test_model', global_step=1000)
```

Output:

```
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
```

Saving by time
```python
# Keep at most the 4 most recent checkpoints, plus one permanent
# checkpoint for every 2 hours of training
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
```

A more detailed explanation
The most detailed explanation actually lives in the source code itself. The English there is straightforward; I'm sure everyone can follow it, so I won't insult your intelligence by paraphrasing.
````python
def __init__(self,
             var_list=None,
             reshape=False,
             sharded=False,
             max_to_keep=5,
             keep_checkpoint_every_n_hours=10000.0,  # the default is ten thousand hours,
                                                     # amusingly -- but we only have so many days
             name=None,
             restore_sequentially=False,
             saver_def=None,
             builder=None,
             defer_build=False,
             allow_empty=False,
             write_version=saver_pb2.SaverDef.V2,
             pad_step_number=False,
             save_relative_paths=False):
  """Creates a `Saver`.

  The constructor adds ops to save and restore variables.

  `var_list` specifies the variables that will be saved and restored. It can
  be passed as a `dict` or a list:

  * A `dict` of names to variables: The keys are the names that will be
    used to save or restore the variables in the checkpoint files.
  * A list of variables: The variables will be keyed with their op name in
    the checkpoint files.

  For example:

  ```python
  v1 = tf.Variable(..., name='v1')
  v2 = tf.Variable(..., name='v2')

  # Pass the variables as a dict:
  saver = tf.train.Saver({'v1': v1, 'v2': v2})

  # Or pass them as a list.
  saver = tf.train.Saver([v1, v2])
  # Passing a list is equivalent to passing a dict with the variable op names
  # as keys:
  saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
  ```

  The optional `reshape` argument, if `True`, allows restoring a variable from
  a save file where the variable had a different shape, but the same number
  of elements and type. This is useful if you have reshaped a variable and
  want to reload it from an older checkpoint.

  The optional `sharded` argument, if `True`, instructs the saver to shard
  checkpoints per device.

  Args:
    var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
      names to `SaveableObject`s. If `None`, defaults to the list of all
      saveable objects.
    reshape: If `True`, allows restoring parameters from a checkpoint
      where the variables have a different shape.
    sharded: If `True`, shard the checkpoints, one per device.
    max_to_keep: Maximum number of recent checkpoints to keep.
      Defaults to 5.
    keep_checkpoint_every_n_hours: How often to keep checkpoints.
      Defaults to 10,000 hours.
    name: String.  Optional name to use as a prefix when adding operations.
    restore_sequentially: A `Bool`, which if true, causes restore of different
      variables to happen sequentially within each device.  This can lower
      memory usage when restoring very large models.
    saver_def: Optional `SaverDef` proto to use instead of running the
      builder. This is only useful for specialty code that wants to recreate
      a `Saver` object for a previously built `Graph` that had a `Saver`.
      The `saver_def` proto should be the one returned by the
      `as_saver_def()` call of the `Saver` that was created for that `Graph`.
    builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
      Defaults to `BaseSaverBuilder()`.
    defer_build: If `True`, defer adding the save and restore ops to the
      `build()` call. In that case `build()` should be called before
      finalizing the graph or using the saver.
    allow_empty: If `False` (default) raise an error if there are no
      variables in the graph. Otherwise, construct the saver anyway and make
      it a no-op.
    write_version: controls what format to use when saving checkpoints.  It
      also affects certain filepath matching logic.  The V2 format is the
      recommended choice: it is much more optimized than V1 in terms of
      memory required and latency incurred during restore.  Regardless of
      this flag, the Saver is able to restore from both V2 and V1 checkpoints.
    pad_step_number: if True, pads the global step number in the checkpoint
      filepaths to some fixed width (8 by default). This is turned off by
      default.
    save_relative_paths: If `True`, will write relative paths to the
      checkpoint state file. This is needed if the user wants to copy the
      checkpoint directory and reload from the copied directory.

  Raises:
    TypeError: If `var_list` is invalid.
    ValueError: If any of the keys or values in `var_list` are not unique.
  """
````
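To tie the constructor arguments back to the earlier examples, here is a small sketch showing `max_to_keep` pruning old checkpoints as `global_step` advances. The directory name and step count are illustrative, and the `tf.compat.v1` shim is only needed when running under TensorFlow 2.x.

```python
# Sketch: save five checkpoints but keep only the three most recent ones.
import os
import tensorflow as tf

if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
    tf = tf.compat.v1
    tf.disable_eager_execution()

v = tf.Variable(0.0, name='v')
saver = tf.train.Saver(max_to_keep=3)
os.makedirs('ckpts', exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 6):
        # Each call appends '-<step>' to the filename and prunes old files
        saver.save(sess, 'ckpts/my_test_model', global_step=step)

state = tf.train.get_checkpoint_state('ckpts')
print(len(state.all_model_checkpoint_paths))  # 3: steps 3, 4 and 5 remain
print(state.model_checkpoint_path)            # ends with 'my_test_model-5'
```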