神经网络模型模型转ONNX
近期由于業(yè)務(wù)需要,需要將訓(xùn)練好的模型轉(zhuǎn)為ONNX格式,為此頗費(fèi)了一番功夫,在此總結(jié)一下吧。。
1、ONNX是一種神經(jīng)網(wǎng)絡(luò)模型保存的中間格式,支持多種格式的模型轉(zhuǎn)為ONNX,也支持使用ONNX導(dǎo)入多種格式的模型,具體見(jiàn)https://github.com/onnx/tutorials;目前其實(shí)ONNX對(duì)于模型的支持還不是太好,主要表現(xiàn)在一些op還不能夠支持;
2、在PyTorch下要將模型保存成ONNX格式需要使用torch.onnx.export()函數(shù),使用該函數(shù)的時(shí)候需要傳入下面參數(shù):
--model:待保存的model,也就是你在程序中已經(jīng)訓(xùn)練好或者初始化好的模型
--input_shape:指定輸入數(shù)據(jù)的大小,也就是輸入數(shù)據(jù)的形狀,是一個(gè)包含輸入形狀元組的列表;
--name:模型的名稱(chēng),即模型的保存路徑;
--verbrose:True或者False,用來(lái)指定輸出模型時(shí)是否將模型的結(jié)構(gòu)打印出來(lái);
--input_names:輸入數(shù)據(jù)節(jié)點(diǎn)的名稱(chēng),數(shù)據(jù)類(lèi)型為包含字符串的列表;一般將這個(gè)名稱(chēng)設(shè)為['data'];
--output_names:輸出數(shù)據(jù)節(jié)點(diǎn)的名稱(chēng),類(lèi)型與輸入數(shù)據(jù)的節(jié)點(diǎn)名稱(chēng)相同;
在成功導(dǎo)出模型后,可以使用ONNX再對(duì)模型進(jìn)行檢查:
import onnx# Load the ONNX model model = onnx.load("alexnet.onnx")# Check that the IR is well formed onnx.checker.check_model(model)# Print a human readable representation of the graph onnx.helper.printable_graph(model.graph)目前PyTorch還不支持導(dǎo)入ONNX格式的模型。
3、使用MXNET導(dǎo)出模型為ONNX時(shí),參考地址:https://cwiki.apache.org/confluence/display/MXNET/ONNX,http://mxnet.incubator.apache.org/versions/master/tutorials/onnx/export_mxnet_to_onnx.html。MXNet模型的保存格式為.json文件+.params文件,.json文件里保存的是模型的結(jié)構(gòu),.params文件中保存的是模型的參數(shù)。使用onnx_mxnet.export_export_model()方法就可以實(shí)現(xiàn)將模型從mxnet轉(zhuǎn)為ONNX格式,該方法需要傳入的參數(shù)為:
--sym:.json文件,也就是保存了網(wǎng)絡(luò)結(jié)構(gòu)的文件
--params:參數(shù)文件
--input_shape:輸入數(shù)據(jù)的形狀,是一個(gè)包含形狀元組的列表
--input_type:輸入數(shù)據(jù)的類(lèi)型;
--模型的保存路徑
4、從MXNet導(dǎo)入ONNX格式模型:需要使用mxnet.contrib.onnx.onnx2mx.import_model.import_model(model_file),這里返回的是sym, arg_arams,aux_params,也就是網(wǎng)絡(luò)結(jié)構(gòu)symbol對(duì)象,保存參數(shù)的字典, 再將其轉(zhuǎn)為MXNet的module對(duì)象(使用mxnet.module.Module()),即可將模型恢復(fù)到mxnet框架下可執(zhí)行的模型。
?
最后,好久沒(méi)有記錄日常學(xué)習(xí)積累的東西了,趁著失眠開(kāi)個(gè)好頭吧,晚安。。。
轉(zhuǎn)載于:https://www.cnblogs.com/puheng/p/10873289.html
總結(jié)
以上是生活随笔為你收集整理的神经网络模型模型转ONNX的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: dockerq启动报错(iptables
- 下一篇: Linux系统开发: 学习Linux下网