Pytorch nn.Parameter()
? ? ? torch.nn.Parameter是繼承自torch.Tensor的子類,其主要作用是作為nn.Module中的可訓(xùn)練參數(shù)使用。它與torch.Tensor的區(qū)別就是nn.Parameter會(huì)自動(dòng)被認(rèn)為是module的可訓(xùn)練參數(shù),即加入到parameter()這個(gè)迭代器中去;而module中非nn.Parameter()的普通tensor是不在parameter中的。
torch.nn.parameter.Parameter(data=None, requires_grad=True)? ? ? nn.Parameter可以看作是一個(gè)類型轉(zhuǎn)換函數(shù),將一個(gè)不可訓(xùn)練的類型 Tensor 轉(zhuǎn)換成可以訓(xùn)練的類型 parameter ,并將這個(gè) parameter 綁定到這個(gè)module 里面(net.parameter() 中就有這個(gè)綁定的 parameter,所以在參數(shù)優(yōu)化的時(shí)候可以進(jìn)行優(yōu)化),所以經(jīng)過類型轉(zhuǎn)換這個(gè)變量就變成了模型的一部分,成為了模型中根據(jù)訓(xùn)練可以改動(dòng)的參數(shù)。使用這個(gè)函數(shù)的目的也是想讓某些變量在學(xué)習(xí)的過程中不斷的修改其值以達(dá)到最優(yōu)化。
? ? ?nn.Parameter()添加的參數(shù)會(huì)被添加到Parameters列表中,會(huì)被送入優(yōu)化器中隨訓(xùn)練一起學(xué)習(xí)更新 ??
? ? ? 在nn.Module類中,pytorch也是使用nn.Parameter來對(duì)每一個(gè)module的參數(shù)進(jìn)行初始化的
?
但是如果 nn.Parameter(requires_grad=False) 那么這個(gè)參數(shù)雖然綁定到模型里了,但是還是不可訓(xùn)練的,只是為了模型完整性這樣寫(例如magiclayout CVPR2021)
requires_grad默認(rèn)值為True,表示可訓(xùn)練,False表示不可訓(xùn)練。
這樣寫還有一個(gè)好處就是,這個(gè)參數(shù)會(huì)隨著模型的被移到cuda上,即如果執(zhí)行過model.cuda(), 那么這個(gè)參數(shù)也就被移到了cuda上了
舉例
import torch from torch import nnclass MyModule(nn.Module):def __init__(self, input_size, output_size):super(MyModule, self).__init__()self.test = torch.rand(input_size, output_size)self.linear = nn.Linear(input_size, output_size)def forward(self, x):return self.linear(x)model = MyModule(4, 2) print(list(model.named_parameters())) import torch from torch import nnclass MyModule(nn.Module):def __init__(self, input_size, output_size):super(MyModule, self).__init__()self.test = nn.Parameter(torch.rand(input_size, output_size))self.linear = nn.Linear(input_size, output_size)def forward(self, x):return self.linear(x)model = MyModule(4, 2) print(list(model.named_parameters()))也可以在外面,通過register_parameter()注冊(cè)
import torch from torch import nnclass MyModule(nn.Module):def __init__(self, input_size, output_size):super(MyModule, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):return self.linear(x)model = MyModule(4, 2) my_test = nn.Parameter(torch.rand(4, 2)) model.register_parameter('test',my_test) print(list(model.named_parameters()))總結(jié)
以上是生活随笔為你收集整理的Pytorch nn.Parameter()的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【概率DP】 ZOJ 3380 Patc
- 下一篇: 从入门到深入!java游戏口袋精灵