TensorFlow 加载多个模型的方法
采用 TensorFlow 的時候,有時候我們需要加載的不止是一個模型,那么如何加載多個模型呢?
原文:https://bretahajek.com/2017/04/importing-multiple-tensorflow-models-graphs/
關于 TensorFlow 可以有很多東西可以說。但這次我只介紹如何導入訓練好的模型(圖),因為我做不到導入第二個模型并將它和第一個模型一起使用。并且,這種導入非常慢,我也不想重復做第二次。另一方面,將一切東西都放到一個模型也不實際。
在這個教程中,我會介紹如何保存和載入模型,更進一步,如何加載多個模型。
加載 TensorFlow 模型
在介紹加載多個模型之前,我們先介紹下如何加載單個模型,官方文檔:https://www.tensorflow.org/programmers_guide/meta_graph。
首先,我們需要創建一個模型,訓練并保存它。這部分我不想過多介紹細節,只需要關注如何保存模型以及不要忘記給每個操作命名。
創建一個模型,訓練并保存的代碼如下:
import tensorflow as tf ### Linear Regression 線性回歸### # Input placeholders x = tf.placeholder(tf.float32, name='x') y = tf.placeholder(tf.float32, name='y') # Model parameters 定義模型的權值參數 W1 = tf.Variable([0.1], tf.float32) W2 = tf.Variable([0.1], tf.float32) W3 = tf.Variable([0.1], tf.float32) b = tf.Variable([0.1], tf.float32)# Output 模型的輸出 linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,name='activation_opt')# Loss 定義損失函數 loss = tf.reduce_sum(tf.square(linear_model - y), name='loss') # Optimizer and training step 定義優化器運算 optimizer = tf.train.AdamOptimizer(0.001) train = optimizer.minimize(loss, name='train_step')# Remember output operation for later aplication # Adding it to a collections for easy acces # This is not required if you NAME your output operation # 記得將輸出操作添加到一個集合中,但如何你命名了輸出操作,這一步可以省略 tf.add_to_collection("activation", linear_model)## Start the session ## sess = tf.Session() sess.run(tf.global_variables_initializer()) # CREATE SAVER saver = tf.train.Saver()# Training loop 訓練 for i in range(10000):sess.run(train, {x: data, y: expected})if i % 1000 == 0:# You can also save checkpoints using global_step variablesaver.save(sess, "models/model_name", global_step=i)# SAVE TensorFlow graph into path models/model_name # 保存模型到指定路徑并命名模型文件名字 saver.save(sess, "models/model_name")注意,這里是第一個重點–對變量和運算命名。這是為了在加載模型后可以使用指定的一些權值參數,如果不命名的話,這些變量會自動命名為類似“Placeholder_1”的名字。在復雜點的模型中,使用領域(scopes)是一個很好的做法,但這里不做展開。
總之,重點就是為了在加載模型的時候能夠調用權值參數或者某些運算操作,你必須給他們命名或者是放到一個集合中。
當保存模型后,在指定保存模型的文件夾中就應該包含這些文件:model_name.index、model_name.meta以及其他文件。如果是采用checkpoints后綴命名模型名字,還會有名字包含model_name-1000的文件,其中的數字是對應變量global_step,也就是當前訓練迭代次數。
現在我們就可以開始加載模型了。加載模型其實很簡單,我們需要的只是兩個函數即可:tf.train.import_meta_graph和saver.restore()。此外,就是提供正確的模型保存路徑位置。另外,如果我們希望在不同機器使用模型,那么還需要設置參數:clear_device=True。
接著,我們就可以通過之前命名的名字或者是保存到的集合名字來調用保存的運算或者是權值參數了。如果使用了領域,那么還需要包含領域的名字才行。而在實際調用這些運算的時候,還必須采用類似{'PlaceholderName:0': data}的輸入占位符,否則會出現錯誤。
加載模型的代碼如下:
sess = tf.Session()# Import graph from the path and recover session # 加載模型并恢復到會話中 saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True) saver.restore(sess, 'models/model_name')# There are TWO options how to access the operation (choose one) # 兩種方法來調用指定的運算操作,選擇其中一個都可以# FROM SAVED COLLECTION: 從保存的集合中調用 activation = tf.get_collection('activation')[0]# BY NAME: 采用命名的方式 activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]# Use imported graph for data # You have to feed data as {'x:0': data} # Don't forget on ':0' part! # 采用加載的模型進行操作,不要忘記輸入占位符 data = 50 result = sess.run(activation, {'x:0': data}) print(result)多個模型
上述介紹了如何加載單個模型的操作,但如何加載多個模型呢?
如果使用加載單個模型的方式去加載多個模型,那么就會出現變量沖突的錯誤,也無法工作。這個問題的原因是因為一個默認圖的緣故。沖突的發生是因為我們將所有變量都加載到當前會話采用的默認圖中。當我們采用會話的時候,我們可以通過tf.Session(graph=MyGraph)來指定采用不同的已經創建好的圖。因此,如果我們希望加載多個模型,那么我們需要做的就是把他們加載在不同的圖,然后在不同會話中使用它們。
這里,自定義一個類來完成加載指定路徑的模型到一個局部圖的操作。這個類還提供run函數來對輸入數據使用加載的模型進行操作。這個類對于我是有用的,因為我總是將模型輸出放到一個集合或者對它命名為activation_opt,并且將輸入占位符命名為x。你可以根據自己實際應用需求對這個類進行修改和拓展。
代碼如下:
import tensorflow as tfclass ImportGraph():""" Importing and running isolated TF graph """def __init__(self, loc):# Create local graph and use it in the sessionself.graph = tf.Graph()self.sess = tf.Session(graph=self.graph)with self.graph.as_default():# Import saved model from location 'loc' into local graph# 從指定路徑加載模型到局部圖中saver = tf.train.import_meta_graph(loc + '.meta',clear_devices=True)saver.restore(self.sess, loc)# There are TWO options how to get activation operation:# 兩種方式來調用運算或者參數# FROM SAVED COLLECTION: self.activation = tf.get_collection('activation')[0]# BY NAME:self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]def run(self, data):""" Running the activation operation previously imported """# The 'x' corresponds to name of input placeholderreturn self.sess.run(self.activation, feed_dict={"x:0": data})### Using the class ### # 測試樣例 data = 50 # random data model = ImportGraph('models/model_name') result = model.run(data) print(result)總結
如果你理解了 TensorFlow 的機制的話,加載多個模型并不是一件困難的事情。上述的解決方法可能不是完美的,但是它簡單且快速。最后給出總結整個過程的樣例代碼,這是在 Jupyter notebook 上的,代碼地址如下:
https://gist.github.com/Breta01/f205a9d27090c18d394fbaab98de7c65#file-importmodulesnotebook-ipynb
最后,給出文章中幾個代碼例子的 github 地址:
歡迎關注我的微信公眾號–機器學習與計算機視覺或者掃描下方的二維碼,在后臺留言,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!
推薦閱讀
1.機器學習入門系列(1)–機器學習概覽(上)
2.機器學習入門系列(2)–機器學習概覽(下)
3.[GAN學習系列] 初識GAN
4.[GAN學習系列2] GAN的起源
5.谷歌開源的 GAN 庫–TFGAN
總結
以上是生活随笔為你收集整理的TensorFlow 加载多个模型的方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 信息安全原理复习资料
- 下一篇: html5 支持音频格式,html5中a