Pytorch学习 - 保存模型和重新加载
生活随笔
收集整理的這篇文章主要介紹了
Pytorch学习 - 保存模型和重新加载
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
Pytorch學習 - 保存和加載模型
- 1. 3個函數(shù)
- 2. 模型不同后綴名的區(qū)別
- 3. 保存和重載模型
參考資料:
Pytorch官方文檔鏈接
某博客
1. 3個函數(shù)
Pytorch中,torch.nn.Module里面的可學習參數(shù)(weights和biases)都存在model.parameters()中。
2. 模型不同后綴名的區(qū)別
pytorch常見保存模型文件的后綴名有 .pt , .pth,.pkl。其實它們并不是在格式上有區(qū)別,只是后綴不同而已(僅此而已),在用torch.save()函數(shù)保存模型文件時,各人有不同的喜好,有些人喜歡用.pt后綴,有些人喜歡用.pth或.pkl.用相同的torch.save()語句保存出來的模型文件沒有什么不同。
3. 保存和重載模型
保存模型主要有兩種方式:
(1)只保存模型的參數(shù),之后使用時再重新構建一個同樣結構的新模型,然后把保存的參數(shù)導入新模型。(推薦)
(2)將整個模型保存下來,然后直接加載整個模型。(有點耗費內存…)
# 保存 torch.save(**model**,PATH) # 讀取 **model = torch.load(PATH)** # 不需要重構模型結構,直接load即可 model.eval()(3)如果沒有訓練完,仍然需要繼續(xù)訓練,除了model_state_dict需要保存,還需要保存optimizer_state_dict,epoch和loss。
# 保存 torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# 加載 相比于前面只需要加載load_state_dict,還需要加載optimizer,epoch,loss等參數(shù) model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']model.eval() # - or - model.train() 《新程序員》:云原生和全面數(shù)字化實踐50位技術專家共同創(chuàng)作,文字、視頻、音頻交互閱讀總結
以上是生活随笔為你收集整理的Pytorch学习 - 保存模型和重新加载的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: zip()和zip(*)的区别与使用
- 下一篇: 数学系列 - 概率论 - 泊松分布和(负