pytorch模型的保存与加载
生活随笔
收集整理的這篇文章主要介紹了
pytorch模型的保存与加载
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
我們先創建一個模型,使用的是pytorch筆記——簡易回歸問題_劉文巾的博客-CSDN博客?的主體框架,唯一不同的是,我這里用的是torch.nn.Sequential來定義模型框架,而不是那篇博客里面的類。
1 保存與加載之前的部分
#導入庫 import torch#數據集 x=torch.linspace(-1,1,100).reshape(-1,1) y=x*x+0.2*torch.rand(x.shape)#定義模型(Sequential比類簡明了很多) net=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))#設置優化函數與損失函數 optimizer=torch.optim.SGD(net.parameters(),lr=0.2)loss_func=torch.nn.MSELoss()#進行訓練 for epoch in range(1000):prediction=net(x)loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一輪參數優化的參與梯度loss.backward()#損失函數反向傳播optimizer.step()#梯度更新#打印模型里面的參數 for a,b in enumerate(net.parameters()):print('no:',a,'\n',b) ''' 一共有四個參數,分別對應的是每一層的w和b no: 0 Parameter containing: tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True) no: 1 Parameter containing: tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True) no: 2 Parameter containing: tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True) no: 3 Parameter containing: tensor([-0.0836], requires_grad=True) '''2 存儲與加載(方法1)——直接保存模型
我們使用torch.save直接保存模型
torch.save(net,'net.pkl')加載模型的時候,直接torch.load即可(可以看到net2參數和net是一樣的)
net2=torch.load('net.pkl') for a,b in enumerate(net2.parameters()):print('no:',a,'\n',b)''' no: 0 Parameter containing: tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True) no: 1 Parameter containing: tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True) no: 2 Parameter containing: tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True) no: 3 Parameter containing: tensor([-0.0836], requires_grad=True) '''3 存儲和加載(方法2)——保存模型參數
保存的話存模型的參數
torch.save(net.state_dict(),'net_params.pkl')加載的話,我們得先重新聲明一個新的神經網絡結構(用Sequential和用類都可以,有了這個新的神經網絡后,才可以把參數傳進去)【因為在聲明新的神經網絡之前,我們現在存的內容即使加載出來了,也不知道這些參數對應的結構是什么】
#聲明一個新的net net3=torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1)) #加載數據 net3.load_state_dict(torch.load('net_params.pkl') )for a,b in enumerate(net3.parameters()):print('no:',a,'\n',b) ''' 和net也是一樣的 no: 0 Parameter containing: tensor([[ 1.0588],[ 0.1654],[ 0.8578],[-0.8756],[-1.0935],[ 0.7588],[ 0.9043],[-0.0723],[-0.4335],[ 0.3010]], requires_grad=True) no: 1 Parameter containing: tensor([-0.5888, 0.9550, -0.1572, -0.2610, -0.4367, 0.3084, -0.3802, -0.3834,0.6192, 0.3012], requires_grad=True) no: 2 Parameter containing: tensor([[ 0.3397, 0.0161, 0.7238, 0.6869, 0.5263, 0.1717, 0.6978, 0.1012,0.3311, -0.2264]], requires_grad=True) no: 3 Parameter containing: tensor([-0.0836], requires_grad=True) '''4 兩種方法的比較
存參數的文件占用的空間少一點,這個在目前這種比較簡單的模型可能還看不出來。對于那種大的模型,省下來的空間還是蠻多的。
總結
以上是生活随笔為你收集整理的pytorch模型的保存与加载的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch笔记:pytorch的乘法
- 下一篇: 文巾解题 203. 移除链表元素