MxNet 模型转Tensorflow pb模型
生活随笔
收集整理的這篇文章主要介紹了
MxNet 模型转Tensorflow pb模型
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
用mmdnn實(shí)現(xiàn)模型轉(zhuǎn)換
參考鏈接:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af
?
?
會(huì)生成resnet50.json(可視化文件) resnet50.npy(權(quán)重參數(shù)) resnet50.pb(網(wǎng)絡(luò)結(jié)構(gòu))三個(gè)文件。
?
生成tf_resnet50.py文件,可以調(diào)用tf_resnet50.py中的KitModel函數(shù)加載npy權(quán)重參數(shù)重新生成原網(wǎng)絡(luò)框架。
打開tf_resnet.py文件,修改load_weights()中的代碼 (tensorflow=1.14.0報(bào)錯(cuò))
try:weights_dict = np.load(weight_file).item()except:weights_dict = np.load(weight_file, encoding='bytes').item()改為
try:weights_dict = np.load(weight_file, allow_pickle=True).item() except:weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()?
基于resnet50.npy和tf_resnet50.py文??件,固化參數(shù),生成PB文件:
import tensorflow as tf import tf_resnet50 as tf_fun def netWork():model=tf_fun.KitModel("./resnet50.npy")return model def freeze_graph(output_graph):output_node_names = "output"data,fc1=netWork()fc1=tf.identity(fc1,name="output")graph = tf.get_default_graph() # 獲得默認(rèn)的圖input_graph_def = graph.as_graph_def() # 返回一個(gè)序列化的圖代表當(dāng)前的圖init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定sess=sess,input_graph_def=input_graph_def, # 等於:sess.graph_defoutput_node_names=output_node_names.split(",")) # 如果有多個(gè)輸出節(jié)點(diǎn),以逗號隔開 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型f.write(output_graph_def.SerializeToString()) # 序列化輸出if __name__ == '__main__':freeze_graph("frozen_insightface_r50.pb")print("finish!")?
?
轉(zhuǎn)載于:https://www.cnblogs.com/qiangz/p/11134240.html
總結(jié)
以上是生活随笔為你收集整理的MxNet 模型转Tensorflow pb模型的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 函数的内置属性
- 下一篇: NRF52 UICR寄存器读写