【转载】 Tensorflow如何直接使用预训练模型(vgg16为例)
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請附上原文出處鏈接和本聲明。
本文鏈接:https://blog.csdn.net/weixin_44633882/article/details/89054159
------------------------------------------------------------------------------------------
主流的CNN模型基本都會使用VGG16或者ResNet等網(wǎng)絡(luò)作為預(yù)訓(xùn)練模型,正好有個朋友和我說發(fā)給他一個VGG16的預(yù)訓(xùn)練模型和代碼,我就整理了一下。在這里也分享一下,方便大家直接使用。
系統(tǒng)環(huán)境
Tensorflow-gpu 1.12.0
Python 3.5.2
資料來源
官方slim說明
https://github.com/tensorflow/models/tree/1af55e018eebce03fb61bba9959a04672536107d/research/slim
主頁里直接可以看到所提供的模型列表和下載鏈接。
我們選擇vgg16來做個示范哈,雖然vgg16的準(zhǔn)確率現(xiàn)在已經(jīng)不算高。
拿到vgg_16.ckpt模型文件!
直接貼上代碼
vgg16預(yù)訓(xùn)練模型使用代碼
import os
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
PROJECT_PATH = os.path.dirname(os.path.abspath(os.getcwd()))
# 預(yù)訓(xùn)練模型位置
tf.app.flags.DEFINE_string('pretrained_model_path', os.path.join(PROJECT_PATH, 'data/vgg_16.ckpt'), '')
FLAGS = tf.app.flags.FLAGS
def vgg_arg_scope(weight_decay=0.1):
"""定義 VGG arg scope.
Args:
weight_decay: The l2 regularization coefficient.
Returns:
An arg_scope.
"""
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(weight_decay),
biases_initializer=tf.zeros_initializer()):
with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
return arg_sc
def vgg16(inputs,scope='vgg_16'):
with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
# Collect outputs for conv2d, fully_connected and max_pool2d.
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],):
# outputs_collections=end_points_collection):
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
net = slim.max_pool2d(net, [2, 2], scope='pool1')
net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], scope='pool2')
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool3')
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
net = slim.max_pool2d(net, [2, 2], scope='pool4')
net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
# net = slim.max_pool2d(net, [2, 2], scope='pool5')
# net = slim.fully_connected(net, 4096, scope='fc6')
# net = slim.dropout(net, 0.5, scope='dropout6')
# net = slim.fully_connected(net, 4096, scope='fc7')
# net = slim.dropout(net, 0.5, scope='dropout7')
# net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
return net
def net():
input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
with slim.arg_scope(vgg_arg_scope()):
conv5_3 = vgg16(input_image) # vgg16網(wǎng)絡(luò)
init = tf.global_variables_initializer()
# restore預(yù)訓(xùn)練模型op
if FLAGS.pretrained_model_path is not None:
variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
slim.get_trainable_variables(),
ignore_missing_vars=True)
with tf.Session() as sess:
sess.run(init)
if FLAGS.pretrained_model_path is not None:
# resotre 預(yù)訓(xùn)練模型
variable_restore_op(sess)
a = sess.run([conv5_3],feed_dict={input_image:np.arange(360000).reshape(1,300,400,3)})
if __name__ == '__main__':
net()
print(tf.trainable_variables())
講一講,代碼里要注意的地方吧,也比較簡單易懂。
1.vgg_arg_scope
def vgg_arg_scope(weight_decay=0.1):
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(weight_decay),
biases_initializer=tf.zeros_initializer()):
with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
return arg_sc
vgg_arg_scope()函數(shù)返回了一個scope參數(shù)空間,使用起來就是with slim.arg_scope(vgg_arg_scope()):,
它規(guī)定了[slim.conv2d, slim.fully_connected]都要滿足什么變量參數(shù),比如:激活函數(shù),參數(shù)初始化。
拿activation_fn=tf.nn.relu來說,所有在這個變量空間中的conv2d卷積和fully_connected全連接都是指定了relu作為激活函數(shù)。
當(dāng)然,這里存在覆蓋是可以的,可以嵌套arg_scope進(jìn)行設(shè)置,內(nèi)層空間覆蓋了外層空間,最內(nèi)層的就是slim.conv2d()里傳入指定的參數(shù)了,這是覆蓋了所有外層的。變量空間在我看來,非常方便,也使網(wǎng)絡(luò)定義變得簡單。
2.slim.repeat()
在VGG16中比如一個conv,其中做了3次相同的卷積,寫出來的代碼就很長,使用repeat()就簡單一句話net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')增強(qiáng)了代碼可讀性,而有人可能會問,那三層卷積層怎么進(jìn)行標(biāo)識呢?
當(dāng)然沒問題,你輸出變量會發(fā)現(xiàn)是類似conv5/conv5_1,在_后面遞增自動標(biāo)記區(qū)分。
3.代碼里是每個層是如何拿到自己對應(yīng)的模型參數(shù)呢?
這個應(yīng)該是有些人的困惑吧,畢竟不知道這個,也只能拿著代碼直接用。這個的關(guān)鍵是變量空間。
網(wǎng)絡(luò)定義完成了,你可以通過 print(tf.trainable_variables()) 來獲得所有網(wǎng)絡(luò)中的變量。
我貼出來 vgg16 中的變量,太多了,撿重要的說,就說說 conv1,可以看到變量是這么標(biāo)識的 vgg_16/conv1/conv1_1/weights,前面有很多前綴,就和龍母報出來自己一堆頭銜一樣,其實是起到一個定位效果。
# [<tf.Variable 'vgg_16/conv1/conv1_1/weights:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_1/biases:0' shape=(64,) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_2/weights:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv1/conv1_2/biases:0' shape=(64,) dtype=float32_ref>,
# <tf.Variable 'vgg_16/conv2/conv2_1/weights:0' shape=(3, 3, 64, 128) dtype=float32_ref>,
在代碼里,我們要讓每個層在預(yù)訓(xùn)練模型里找到自己對應(yīng)的參數(shù),就必須這么定義變量空間。
with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d]):
net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
看到了 scope 和 ‘vgg_16’兩個,其實 scope 我們也傳入的是 ’vgg_16’,tf.variable_scope() 的參數(shù),前兩個是 name_or_scope, default_name。默認(rèn)名稱是當(dāng) name_or_scope 為空時,使用的默認(rèn)名稱。
這么整理一下,'vgg_16', 后面的slim.repeat()里的scope='conv1',還有自動標(biāo)記的 conv1_1。
連起來就是 vgg_16/conv1/conv1_1 。
4. 預(yù)訓(xùn)練模型restore。
先準(zhǔn)備op,而且若 pretrained_model_path 不為空,才加入和使用 variable_restore_op
if FLAGS.pretrained_model_path is not None:
variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path,
slim.get_trainable_variables(),
ignore_missing_vars=True)
在Session()中使用
if FLAGS.pretrained_model_path is not None:
variable_restore_op(sess)
講解完畢!哦,還有補充一下,一般vgg16來說,只會拿conv5_3的輸出,繼續(xù)做fine-tune。所以,你只用conv5_3,測試的時候是不用在意輸入圖片的大小的,因為都是卷積嘛。但是,我測試的時候,傳入了個(1,3,6,3)的數(shù)組,出現(xiàn)了這么一個錯。想了想,嗯,應(yīng)該是這個數(shù)組做不了那么多次卷積的,所以Tensorflow報錯了。(這里只是簡單記錄一下),所以用一個大一些的數(shù)組傳入就可以啦
2019-04-06 12:20:14.650154: F tensorflow/stream_executor/cuda/cuda_dnn.cc:542] Check failed: cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, dims.data(), strides.data()) == CUDNN_STATUS_SUCCESS (3 vs. 0)batch_descriptor: {count: 1 feature_map_count: 128 spatial: 0 1 value_min: 0.000000 value_max: 0.000000 layout: BatchDepthYX}
bash: line 1: 2492 Aborted (core dumped) env "PYTHONUNBUFFERED"="1" "PYTHONPATH"="/tmp/pycharm_project_299:/home/benke/.pycharm_helpers/pycharm_matplotlib_backend" "PYCHARM_HOSTED"="1" "JETBRAINS_REMOTE_RUN"="1" "PYCHARM_MATPLOTLIB_PORT"="65407" "PYTHONIOENCODING"="UTF-8" '/opt/anaconda3/bin/python' '-u' '/tmp/pycharm_project_299/data/vgg.py'
---------------------------------------------------------------------------------------------
轉(zhuǎn)者注:
tensorflow官方預(yù)訓(xùn)練模型下載鏈接:
https://github.com/tensorflow/models/tree/master/research/slim
總結(jié)
以上是生活随笔為你收集整理的【转载】 Tensorflow如何直接使用预训练模型(vgg16为例)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 皮鞋磨脚怎么办
- 下一篇: WebSocket和HTTP的区别与联系