conv2d() received an invalid combination of arguments问题解决
在學習動手學深度學習風格遷移這一部分的時候,程序運行的時候抱錯:conv2d() received an invalid combination of arguments
具體來說,先使用函數SynthesizedImage定義一個圖像,它的權重是更新的目標,經get_inits實例化,通過訓練更新圖像的權重,獲得風格遷移后的圖像。
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weightdef get_inits(content_img, lr, lr_decay_epoch, init_random):gen_img = SynthesizedImage(content_img.shape).to(device)if not init_random: gen_img.weight.data.copy_(content_img.data)optimizer = torch.optim.Adam(gen_img.parameters(), lr=lr)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, 0.8)return gen_img(), optimizer, scheduler參考:在python中遇到的錯誤(二):用pytorch的CNN發生的報錯_游魚不知夏的博客-CSDN博客
發現可能是初始化數據出了問題。經過檢查發現函數get_inits返回值寫成了是gen_img,它的格式是:
<class '__main__.SynthesizedImage'>返回的參數應該寫成gen_img(),返回后的格式是:
<class 'torch.nn.parameter.Parameter'>這樣就不會報錯了。
這里蘊含一個知識點:pytorch模型定義。下面舉幾個例子就能明白,為什么gen_img的格式是<class '__main__.SynthesizedImage'>, gen_img()的格式是<class 'torch.nn.parameter.Parameter'>
簡單地說,就是將模型實例化之后,gen_img代表模型自身,gen_img()執行了魔法函數forward(),得到forward()的返回值
第一個例子
class Net1(nn.Module):def __init__(self, img_shape, **kwargs):super(Net1, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weightmodel = Net1([2,3])>>print(model)
Net1()>>print(model())
Parameter containing: tensor([[0.6031, 0.3673, 0.7362],[0.9071, 0.1086, 0.0191]], requires_grad=True)>>print(type(model))
<class '__main__.Net1'>>>print(type(model()))
<class 'torch.nn.parameter.Parameter'>第二個例子
class Net2(nn.Module):def __init__(self, a):super(Net2, self).__init__()self.conv1 = nn.Conv2d(3, 5, 3)def forward(self, x): return self.conv1model = Net2(1)>>print(model)
Net2((conv1): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1)) )>>print(model(1))
Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))>>print(type(model))
<class '__main__.Net2'>>>print(type(model(1)))
<class 'torch.nn.modules.conv.Conv2d'>第三個例子
class Net3(nn.Module):def __init__(self, a):super(Net3, self).__init__()self.weight = 123def forward(self): return 456model = Net3(1)>>print(model)
Net3()>>print(model())
456>>print(type(model))
<class '__main__.Net3'>>>print(type(model()))
<class 'int'>總結
以上是生活随笔為你收集整理的conv2d() received an invalid combination of arguments问题解决的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Cocoa动画编程指南
- 下一篇: 现网必用的主备冗余技术,VRRP理论+配