pytorch保存模型pth_Day159:模型的保存与加载
網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)可以分開(kāi)的保存和加載,因此,pytorch保存模型有兩種方法:
- 注意到,兩者都是用torch.save(obj, dir)實(shí)現(xiàn),這個(gè)函數(shù)的作用是將對(duì)象保存到磁盤(pán)中,它的內(nèi)部是使用Python的pickle實(shí)現(xiàn)
- PyTorch約定使用.pt或.pth后綴命名保存文件
- 兩種方法的區(qū)別其實(shí)就是obj參數(shù)的不同:前者的obj是整個(gè)model對(duì)象,后者的obj是從model對(duì)象里獲取存儲(chǔ)了model參數(shù)的詞典,推薦用第二種,雖然麻煩了一丁點(diǎn),但是比較靈活,有利于實(shí)現(xiàn)預(yù)訓(xùn)練、參數(shù)遷移等操作
一般加載模型是在訓(xùn)練完成后用模型做測(cè)試,這時(shí)候加載模型記得要加上model.eval(),把模型切換到evaluation模式,這時(shí)候會(huì)調(diào)整dropout和bactch的模式。
- 網(wǎng)絡(luò)結(jié)構(gòu)及其參數(shù)的保存與加載:load整個(gè)模型,完成了模型的定義和參數(shù)的加載這兩個(gè)過(guò)程
- 只保存/加載模型參數(shù):需要先創(chuàng)建一個(gè)網(wǎng)絡(luò)模型,然后再load_state_dict()
重點(diǎn)介紹一下這種方法,一般訓(xùn)完一個(gè)模型之后不會(huì)只保存一個(gè)模型的參數(shù),為了方便后續(xù)操作,比如恢復(fù)訓(xùn)練、參數(shù)遷移等,會(huì)保存當(dāng)前狀態(tài)的一個(gè)快照,格式以字典的格式存儲(chǔ),具體信息可以根據(jù)自己的需要,下面列出幾個(gè)方面:
- 模型參數(shù)(不帶模型的結(jié)構(gòu))
- 優(yōu)化器參數(shù)
- loss
- epoch
- args
把這些信息用字典包裝起來(lái),然后保存即可。這種方式保存的只是參數(shù),所以,在加載時(shí)需要先創(chuàng)建好模型,然后再把參數(shù)加載進(jìn)去,如下:
# 獲得保存信息save_data = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'epoch': epoch, 'args': args ...}# 保存torch.save(save_data , path)# 加載參數(shù)model_CKPT = torch.load(path)model = Mymodel()optimizer = Myoptimizer()model.load_state_dict(model_CKPT ['model_state_dict'])optimizer.load_state_dict(model_CKPT ['optimizer_state_dict'])...# 若對(duì)于加載參數(shù),用函數(shù)表示,比如:def load_checkpoint(model, checkpoint_path, optimizer): if checkpoint_path != None: model_CKPT = torch.load(checkpoint_path) model.load_state_dict(model_CKPT['state_dict']) print('loading checkpoint!') optimizer.load_state_dict(model_CKPT['optimizer']) return model, optimizer但是,對(duì)于已經(jīng)保存好的模型參數(shù),我們可能修改了一部分網(wǎng)絡(luò)結(jié)構(gòu),比如加了一些,刪除一些等等,那么需要過(guò)濾這些參數(shù),加載方式如下:
def load_checkpoint(model, checkpoint_path, optimizer, loadOptimizer): if checkpoint_path != 'None': print("loading checkpoint...") model_dict = model.state_dict()# 修改后的模型隨機(jī)初始化的參數(shù) modelCheckpoint = torch.load(checkpoint_path) # 修改前的模型參數(shù) pretrained_dict = modelCheckpoint['model_state_dict'] # 過(guò)濾操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} # 獲取修改后模型所需參數(shù) model_dict.update(new_dict) # 打印出來(lái),更新了多少參數(shù) print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))# 修改后模型加載所需的,已經(jīng)訓(xùn)練好的參數(shù) model.load_state_dict(model_dict) print("loaded finished!") # 如果不需要更新優(yōu)化器那么設(shè)置為false if loadOptimizer == True: optimizer.load_state_dict(modelCheckpoint['optimizer_state_dict']) print('loaded! optimizer') else: print('not loaded optimizer') else: print('No checkpoint_path is included') return model, optimizer參考1:https://blog.csdn.net/MoreAction_/article/details/107967053
參考2:https://zhuanlan.zhihu.com/p/38056115
總結(jié)
以上是生活随笔為你收集整理的pytorch保存模型pth_Day159:模型的保存与加载的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: adxl276怎么添加到proteus中
- 下一篇: spark读取hdfs路径下的数据_到底