pytorch笔记 pytorch模型中的parameter与buffer
1 模型的兩種參數(shù)
在 Pytorch 中一種模型保存和加載的方式如下:(具體見pytorch模型的保存與加載_劉文巾的博客-CSDN博客)
#save torch.save(net.state_dict(),PATH)#load model=MyModel(*args,**kwargs) model.load_state_dict(torch.load(PATH)) model.eval模型保存的是 net.state_dict()?的返回對象。
net.state_dict()?的返回對象是一個?OrderDict?,它以鍵值對的形式包含模型中需要保存下來的參數(shù)
上例模型中的參數(shù)就是線性層的 weight 和 bias.
?
模型中需要保存下來的參數(shù)包括兩種:
- 一種是反向傳播需要被optimizer更新的,稱之為 parameter
- 一種是反向傳播不需要被optimizer更新,稱之為 buffer
第一種參數(shù)我們可以通過?model.parameters()?返回;
第二種參數(shù)我們可以通過?model.buffers()?返回。
因為我們的模型保存的是?state_dict?返回的?OrderDict,所以這兩種參數(shù)不僅要滿足是否需要被更新的要求,還需要被保存到OrderDict。
2 Parameter
?
Parameter參數(shù)有兩種創(chuàng)建方式:
像我們前面的nn.Conv1d,nn.Linear,nn.RNN等模型,里面的權(quán)重參數(shù)等會被自動認為是Parameter 參數(shù)
3 buffer
buffer參數(shù)我們需要創(chuàng)建tensor, 然后將tensor通過register_buffer()進行注冊,可以通過model.buffers()?返回,注冊完后參數(shù)也會自動保存到OrderDict中去。
4 為什么要注冊???????
為什么不直接將不需要進行參數(shù)修改的變量作為模型類的成員變量就好了,還要進行注冊?
5 實例說明
import torch class net(torch.nn.Module):def __init__(self):super(net,self).__init__()#創(chuàng)建bufferself.register_buffer('my_buffer',torch.Tensor([1,2,3]))self.a=torch.Tensor([1])self.param1=torch.nn.Parameter(torch.Tensor([1,3,5,7,9]))#方法1 創(chuàng)建的parameterparam2=torch.nn.Parameter(torch.Tensor([2,4,6,8,0]))self.register_parameter('param2',param2)self.l=torch.nn.Linear(1,10)def forward(self,x):passn=net()for i in n.state_dict():print(i,n.state_dict()[i]) print('*'*10) for i in n.parameters():print(i) print('*'*10) for i in n.buffers():print(i) print('*'*10)''' param1 tensor([1., 3., 5., 7., 9.]) param2 tensor([2., 4., 6., 8., 0.]) my_buffer tensor([1., 2., 3.]) l.weight tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]]) l.bias tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732]) ********** Parameter containing: tensor([1., 3., 5., 7., 9.], requires_grad=True) Parameter containing: tensor([2., 4., 6., 8., 0.], requires_grad=True) Parameter containing: tensor([[-0.1490],[-0.2445],[-0.5296],[-0.3687],[-0.9683],[ 0.3491],[-0.8726],[-0.7213],[ 0.3201],[-0.9994]], requires_grad=True) Parameter containing: tensor([ 0.6718, 0.3055, 0.7755, 0.3780, -0.8169, 0.3663, -0.6937, -0.3136,0.6907, 0.8732], requires_grad=True) ********** tensor([1., 2., 3.]) ********** '''?
總結(jié)
以上是生活随笔為你收集整理的pytorch笔记 pytorch模型中的parameter与buffer的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch学习笔记 torchnn.
- 下一篇: 文巾解题 16. 最接近的三数之和