nn.Sequential与nn.ModuleList
1、nn.Sequential
模塊按照順序進(jìn)行排列的,確保輸入與輸出模塊的通道數(shù)相同(實(shí)際上是feature map數(shù)量)。
nn.Sequential寫法有3種:
第一種寫法:
創(chuàng)建nn.Sequential()對象并用add_module方法添加
self.con2=nn.Sequential()self.con2.add_module('conv',nn.Conv2d(16,32,3,1,1))self.con2.add_module('relu',nn.ReLU())self.con2.add_module('bn',nn.BatchNorm2d(32))self.con2.add_module('pool',nn.MaxPool2d(2,2))第二種寫法:
nn.Sequential(*多個(gè)層class的實(shí)例)
self.con1=nn.Sequential(nn.Conv2d(3,16,3,1,1),nn.ReLU(),nn.BatchNorm2d(16),nn.MaxPool2d(2,2))第三種寫法:
nn.Sequential(OrderedDict([*多個(gè)(層名,層class的實(shí)例)]))
self.con3=nn.Sequential(OrderedDict([('con',nn.Conv2d(32,64,3,1,1)),('bn',nn.BatchNorm2d(64)),('relu',nn.ReLU()),('pool',nn.MaxPool2d(2,2))]))2、nn.ModuleList
nn.ModuleList僅僅類似于pytho中的list類型,只是將一系列層裝入列表,并沒有實(shí)現(xiàn)forward()方法,因此也不會(huì)有網(wǎng)絡(luò)模型產(chǎn)生的副作用。nn.ModuleList接受的必須是subModule類型,即不管ModuleList包裹了多少個(gè)列表,內(nèi)嵌的所有列表的內(nèi)部都要是可迭代的Module的子類。參考
#列表需要遍歷self.liner=nn.ModuleList([nn.Linear(64*56*56,1000),nn.Linear(1000,1000),nn.Linear(1000,10)])for liner in self.liner:ly4=liner(ly4)3、nn.Sequential與nn.ModuleList的區(qū)別
不同點(diǎn)1:
nn.Sequential內(nèi)部實(shí)現(xiàn)了forward函數(shù),因此可以不用寫forward函數(shù)。而nn.ModuleList則沒有實(shí)現(xiàn)內(nèi)部forward函數(shù)。
不同點(diǎn)2:
nn.Sequential可以使用OrderedDict對每層進(jìn)行命名
不同點(diǎn)3:
nn.Sequential里面的模塊按照順序進(jìn)行排列的,所以必須確保前一個(gè)模塊的輸出大小和下一個(gè)模塊的輸入大小是一致的。而nn.ModuleList 并沒有定義一個(gè)網(wǎng)絡(luò),它只是將不同的模塊儲存在一起,這些模塊之間并沒有什么先后順序可言。
不同點(diǎn)4:
有的時(shí)候網(wǎng)絡(luò)中有很多相似或者重復(fù)的層,我們一般會(huì)考慮用 for 循環(huán)來創(chuàng)建它們,而不是一行一行地寫
查看Pytorch網(wǎng)絡(luò)的各層輸出(feature map)、權(quán)重(weight)、偏置(bias)
綜上代碼如下:
import torch.nn as nn import torch from collections import OrderedDict import cv2 import matplotlib.pyplot as plt import numpy as npclass mytest(nn.Module):def __init__(self):super(mytest, self).__init__()#第一種寫法self.con1=nn.Sequential(nn.Conv2d(3,16,3,1,1),nn.ReLU(),nn.BatchNorm2d(16),nn.MaxPool2d(2,2))# 第二種寫法self.con2=nn.Sequential()self.con2.add_module('conv',nn.Conv2d(16,32,3,1,1))self.con2.add_module('relu',nn.ReLU())self.con2.add_module('bn',nn.BatchNorm2d(32))self.con2.add_module('pool',nn.MaxPool2d(2,2))# 第三種寫法self.con3=nn.Sequential(OrderedDict([('con',nn.Conv2d(32,64,3,1,1)),('bn',nn.BatchNorm2d(64)),('relu',nn.ReLU()),('pool',nn.MaxPool2d(2,2))]))#列表需要遍歷self.liner=nn.ModuleList([nn.Linear(64*56*56,1000),nn.Linear(1000,1000),nn.Linear(1000,10)])# self.liner = nn.Linear(64 * 56 * 56, 1000)def forward(self,x):ly1=self.con1(x)ly2=self.con2(ly1)ly3=self.con3(ly2)ly4=ly3.view(ly3.size()[0],-1)# x=self.liner[0](x)for liner in self.liner:ly4=liner(ly4)return ly1,ly2,ly3,ly4#返回各層結(jié)果x=torch.randn(2,3,448,448) model=mytest() _,_,_,pred=model(x) print(pred.shape) print(model.con1[0]) print(model.con2.pool) print(model.con3[0]) print(model.con3.relu) print(model.liner[0])#https://blog.csdn.net/xwmwanjy666/article/details/100927858 img0=cv2.imread('D:/data/testpic/3.jpg') print(img0.shape) img=cv2.resize(img0,(448,448)) input=torch.tensor(img).float()#必須轉(zhuǎn)為float類型的數(shù)據(jù) input=input.permute(2,0,1) input=input.unsqueeze(0)#相當(dāng)于增加一個(gè)batch維度 print(input.shape)ly1,ly2,ly3,ly4=model(input)#各層的shape print(ly1.shape) print(ly2.shape) print(ly3.shape) print(ly4.shape)#各層部分特征圖 plt.subplot(3,2,1) img1_1=ly1.squeeze(0).permute(1,2,0).detach().numpy() img1_1=img1_1[:,:,:3]#取3個(gè)通道顯示 # img1_1=img1_1[:,:,0] # img = np.ascontiguousarray(img, dtype=np.float32) plt.title("conv1:feature map1") plt.imshow(img1_1)plt.subplot(3,2,2) img1_2=ly1.squeeze(0).permute(1,2,0).detach().numpy() img1_2=img1_2[:,:,3:6] plt.title("conv1:feature map2") plt.imshow(img1_2)plt.subplot(3,2,3) img2_1=ly2.squeeze(0).permute(1,2,0).detach().numpy() img2_1=img2_1[:,:,:3] plt.title("conv2:feature map1") plt.imshow(img2_1)plt.subplot(3,2,4) img2_2=ly2.squeeze(0).permute(1,2,0).detach().numpy() img2_2=img2_2[:,:,3:6] plt.title("conv2:feature map2") plt.imshow(img2_2)plt.subplot(3,2,5) img3_1=ly3.squeeze(0).permute(1,2,0).detach().numpy() img3_1=img3_1[:,:,:3] plt.title("conv3:feature map1") plt.imshow(img3_1)plt.subplot(3,2,6) img3_2=ly3.squeeze(0).permute(1,2,0).detach().numpy() img3_2=img3_2[:,:,3:6] plt.title("conv3:feature map2") plt.imshow(img3_2)# # plt.imshow(img0)plt.show()? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 各卷積層特征圖
?
參考:
https://blog.csdn.net/e01528/article/details/84397174
https://zhuanlan.zhihu.com/p/75206669
?
總結(jié)
以上是生活随笔為你收集整理的nn.Sequential与nn.ModuleList的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 解读游戏“仙股”飞鱼科技年内涨幅超400
- 下一篇: Educoder -- Web程序设计基