pytorch基础知识整理(三)模型保存与加载
1, torch.save(); troch.load()
torch.save()使用python的pickle模塊把目標(biāo)保存到磁盤,可以用來保存模型、張量、字典等,文件后綴名一般用pth或pt或pkl。torch.load()使用python的pickle模塊實現(xiàn)從磁盤加載。可以用此來直接保存或加載完整模型:
torch.save(model, 'PATH.pth') model = torch.load('PATH.pth')注意:pytorch1.6以后保存的模型使用zip壓縮,所以保存的模型無法被1.6以前的版本加載,如果要跨版本使用,需要做以下修改
torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)2, .state_dict(); .load_state_dict()
模型的框架已經(jīng)在程序代碼中了,因此訓(xùn)練好的模型只需要保存模型的參數(shù)即可供推理使用。model.state_dict()以字典的形式保存模型的參數(shù),字典的鍵是參數(shù)名,值是參數(shù)值的張量。得到狀態(tài)字典后還需用torch.save()固化到磁盤。
除模型外,優(yōu)化器optimizer也可以保存和加載狀態(tài)字典。
注意在多卡GPU訓(xùn)練時,保存和加載模型需要在model后加上module,即
torch.save(model.module.state_dict(), 'PATH.pth') model.module.load_state_dict(torch.load('PATH.pth'))3, 保存checkpoint
如果是訓(xùn)練中途保存用于繼續(xù)訓(xùn)練,就不僅要保存權(quán)重參數(shù),還要保存當(dāng)前epoch,優(yōu)化器的狀態(tài),當(dāng)前的損失值等,可以統(tǒng)一打包到一個字典中保存為checkpoint,此時文件后綴名一般用tar。
#保存: torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH) ##加載: checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']總結(jié)
以上是生活随笔為你收集整理的pytorch基础知识整理(三)模型保存与加载的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch基础知识整理(一)自动求导
- 下一篇: pytorch基础知识整理(四) 模型