Tensorflow加载模型(进阶版):如何利用预训练模型进行微调(fintuning)
我們要使用別人已經訓練好的模型,就必須將.ckpt文件中的參數加載進來。我們如何有選擇的加載.ckpt文件中的參數呢。首先我們要查看.ckpt都保存了哪些參數:
上代碼:
import tensorflow as tf import os from tensorflow.python import pywrap_tensorflowmodel_dir='./model'#設置模型所在文件夾 checkpoint_path = os.path.join(model_dir, "fineturing_model.ckpt")#定位ckpt文件 # 從checkpoint中讀出數據 reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) # reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法 var_to_shape_map = reader.get_variable_to_shape_map() # 輸出權重tensor名字和值 for key in var_to_shape_map:print("tensor_name: ", key,reader.get_tensor(key).shape)然后我們,照著原來的模型來搞清楚該參數是否應該加載:
接下來我們來看如何有選擇的加載,代碼如下:
import tensorflow as tf import tensorflow.contrib.slim as slim #我們要用到的模塊with tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)saver1 = tf.train.Saver() #設置默認圖model_name = 'xxxx/xxx/model.ckpt'#saver = tf.train.import_meta_graph('xxx/xxx/model.meta')#variables = tf.contrib.framework.get_variables_to_restore()#有選擇的恢復參數include = ['var_name/wc1','var_name/wc2','var_name/wc3a'.....]variables_to_restore = slim.get_variables_to_restore(include=include)saver = tf.train.Saver(variables_to_restore) print(variables_to_restore) #打印要加載的參數saver.restore(sess,model_name)saver1.save(sess,'./model2/fineturing_model.ckpt')注意:這里應該格外注意saver1 和 saver 的先后順序關系。saver1 = tf.train.Saver()。默認將我們模型中所出現的參數(set1)全都保存,類似于限制默認圖參數。而saver =? tf.train.import_meta_graph()。表示將模型中所出現的參數集(set2)加載進來,可以理解為定義默認圖中就這些參數(set2)。但是set2 真含于?set1,因此如果saver1在saver后定義,當保存某個參數A存在于set1但不存在與set2時,會報錯:Key NotFoundError (see above for traceback):?Variable_xxx not found in checkpoint。我們也可以添加tf.reset_default_graph()來設置默認圖。
到這里我們就知道了如何有選擇的加載預訓練模型來進行遷移學習了。
我們知道,有的模型后綴名是.model。例,c3d 網絡在UCF101上的一個預訓練模型:sports1m_finetuning_ucf101.model。這種應該怎么加載呢?方法其實是一樣的,以sports1m_finetuning_ucf101.model為例:
import tensorflow as tf import tensorflow.contrib.slim as slim #我們要用到的模塊with tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)saver1 = tf.train.Saver() #設置默認圖model_name = 'xxxx/xxx/sports1m_finetuning_ucf101.model'#有選擇的恢復參數include = ['var_name/wc1','var_name/wc2','var_name/wc3a','var_name/wc3b','var_name/wc4a',"var_name/wc4b","var_name/wc5a",'var_name/wc5b','var_name/bc1','var_name/bc2','var_name/bc3a','var_name/bc3b','var_name/bc4a',"var_name/bc4b","var_name/bc5a",'var_name/bc5b']variables_to_restore = slim.get_variables_to_restore(include=include)saver = tf.train.Saver(variables_to_restore) print(variables_to_restore) #打印要加載的參數saver.restore(sess,model_name)saver1.save(sess,'./model2/fineturing_model.ckpt')相關代碼,注意中文注釋部分:
model_name = "./model/fineturing_model.ckpt"def run_training(batch_size,dropout,epochs):#weight = [weights['wc1'],weights['wc2'],weights['wc3a'],weigths['wc3b'],weights['wc4a'],weights['wc4b'],weights['wc5a'],weights['wc5b']]with tf.Graph().as_default():with tf.variable_scope('var_name') as var_scope: #變量初始化過程weights = {'wc1': _variable_with_weight_decay('wc1', [3, 3, 3, 3, 64], 0.04, 0.00),'wc2': _variable_with_weight_decay('wc2', [3, 3, 3, 64, 128], 0.04, 0.00),'wc3a': _variable_with_weight_decay('wc3a', [3, 3, 3, 128, 256], 0.04, 0.00),'wc3b': _variable_with_weight_decay('wc3b', [3, 3, 3, 256, 256], 0.04, 0.00),'wc4a': _variable_with_weight_decay('wc4a', [3, 3, 3, 256, 512], 0.04, 0.00),'wc4b': _variable_with_weight_decay('wc4b', [3, 3, 3, 512, 512], 0.04, 0.00),'wc5a': _variable_with_weight_decay('wc5a', [3, 3, 3, 512, 512], 0.04, 0.00),'wc5b': _variable_with_weight_decay('wc5b', [3, 3, 3, 512, 512], 0.04, 0.00),'cam':_variable_with_weight_decay('cam', [1,1,512,c3d_model.NUM_CLASSES], 0.04,0.00),}biases = {'bc1': _variable_with_weight_decay('bc1', [64], 0.04, 0.0),'bc2': _variable_with_weight_decay('bc2', [128], 0.04, 0.0),'bc3a': _variable_with_weight_decay('bc3a', [256], 0.04, 0.0),'bc3b': _variable_with_weight_decay('bc3b', [256], 0.04, 0.0),'bc4a': _variable_with_weight_decay('bc4a', [512], 0.04, 0.0),'bc4b': _variable_with_weight_decay('bc4b', [512], 0.04, 0.0),'bc5a': _variable_with_weight_decay('bc5a', [512], 0.04, 0.0),'bc5b': _variable_with_weight_decay('bc5b', [512], 0.04, 0.0),}images_placeholder, labels_placeholder = placeholder_inputs(batch_size)logits, CAM = c3d_model.inference_c3d(images_placeholder[:,:,:,:,:], dropout, batch_size, weights, biases)#導入模型結構loss_ = loss(logits,labels_placeholder)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels_placeholder), tf.float32))train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_)#softmax_ = soft(logits)print('**********')#reader = pywrap_tensorflow.NewCheckpointReader(model_name)with tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)saver1 = tf.train.Saver()saver = tf.train.import_meta_graph('./model/fineturing_model.ckpt.meta')#tf.reset_default_graph()saver.restore(sess,model_name) saver1.save(sess,'./model2/fineturing_model.ckpt')for i in range ():......總結
以上是生活随笔為你收集整理的Tensorflow加载模型(进阶版):如何利用预训练模型进行微调(fintuning)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Tensorflow载入模型详解,方法一
- 下一篇: 参数形参错误之 SyntaxError: