PyTorch框架学习九——网络模型的构建
PyTorch框架學習九——網絡模型的構建
- 一、概述
- 二、nn.Module
- 三、模型容器Container
- 1.nn.Sequential
- 2.nn.ModuleList
- 3.nn.ModuleDict()
- 4.總結
筆記二到八主要介紹與數據有關的內容,這次筆記將開始介紹網絡模型有關的內容,首先我們不追求網絡內部各層的具體內容,重點關注模型的構建,學會了如何構建模型,然后再開始一些具體網絡層的學習。
一、概述
模型有關的內容主要如下圖所示:
主要是模型的搭建和權值的初始化兩個問題,而模型的搭建里,首先需要構建單獨的網絡層,然后將這些網絡層按順序拼接起來,就構成了一個模型,然后進行某種權值初始化,就可以用于訓練數據。
今天介紹PyTorch中是如何實現模型創建的,具體內部的卷積、池化、激活函數等知識下次筆記介紹。上述的所有內容,在PyTorch中都有一個叫nn.Module的模塊來實現。
看一個LeNet模型的例子:
從上圖可以看出LeNet模型經過了這樣一個網絡層的流程:
那我們要來搭建這個模型的話,就要先單獨構建卷積層Conv,池化層pool,全連接層fc,然后按照上面的順序進行拼接,拼接后的整體才是一個構建好的網絡模型。
看一下LeNet的模型構建的代碼:
class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out可以看出__init__()函數實現了對每一個單獨的網絡層的構建,forward()函數實現了子網絡層的拼接。
二、nn.Module
介紹nn.Module之前先看一下torch.nn里四個重要的模塊:
這里重點介紹nn.Parameter和nn.Module。
nn.Module來構建網絡層時會創建8個字典管理它的不同屬性,分別如下所示:
- parameters:存儲管理nn.Parameter類。
- modules:存儲管理nn.Module類。
- buffers:存儲管理緩沖屬性,如BN層中的running_mean。
- ×××_hooks(5個):存儲管理鉤子函數(目前不了解)。
下面的代碼是創建一個module時對8個字典的初始化:
def __init__(self):"""Initializes internal Module state, shared by both nn.Module and ScriptModule."""torch._C._log_api_usage_once("python.nn_module")self.training = Trueself._parameters = OrderedDict()self._buffers = OrderedDict()self._backward_hooks = OrderedDict()self._forward_hooks = OrderedDict()self._forward_pre_hooks = OrderedDict()self._state_dict_hooks = OrderedDict()self._load_state_dict_pre_hooks = OrderedDict()self._modules = OrderedDict()注意:
三、模型容器Container
模型容器有三種,如下圖所示:
1.nn.Sequential
功能:是nn.Module的容器,用于按順序包裝一組網絡層。
還是以LeNet為例,我們將LeNet分成features和classifier兩部分,每個部分都是一個sequential:
代碼如下:
但是,這種構建網絡的方式有一個小問題,每一層網絡層都會自動按順序編一個號作為name,如features這個Sequential里每層網絡層在module屬性內部是這樣的:
這里只有六個網絡層,所以還可以在短時間內找到你需要的那一個,但是當層數非常多的時候,這種數字命名的方式就很不友好,而Sequential也有相應的應對方法,即為每一層網絡命名,具體代碼如下所示:
與原來不同的地方就是,構建了一個OrderedDict字典來存放鍵值對,key就是每一層網絡的名字,value就是具體的網絡層實現,看一下此時的module屬性內部:
這樣就很好尋找所需要的某一層網絡。
綜上,Sequential的特點:
2.nn.ModuleList
也是nn.module的容器,用于包裝一組網絡層,以迭代方式調用網絡層。
主要方法:
這種容器比較適合構建大量重復的網絡層,因為利用了迭代的方法,下面就是構建20個線性層的例子
class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return x3.nn.ModuleDict()
也是nn.module的容器,用于包裝一組網絡層,以索引方式調用網絡層。
主要方法:
這種容器的特點是,因為鍵值對可以索引的特性,可用于選擇網絡層:
class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return xnet = ModuleDict()fake_img = torch.randn((4, 10, 32, 32))output = net(fake_img, 'conv', 'relu')print(output)我們構建了conv、pool以及relu、prelu,然后我們選擇使用conv和relu。
4.總結
對于上述提及的三種容器,它們各自的特點以及適用范圍如下所示:
總結
以上是生活随笔為你收集整理的PyTorch框架学习九——网络模型的构建的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 粒子群优化算法(Particle Swa
- 下一篇: 集成方法Ensemble Method(