Inception代码解读
生活随笔
收集整理的這篇文章主要介紹了「Inception代碼解讀」,小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
Inception代碼解讀
目錄
- Inception代碼解讀
- 概述
- Inception網絡結構圖
- inception網絡結構框架
- inception代碼細節分析
概述
inception相比起最開始興起的AlexNet和VGG,做了以下重要改動:
1)改變了“直通”型的網絡結構,將一個大的卷積核做的事情分成了幾個小的卷積核來完成;
2)這樣帶來的另一個好處是可以得到不同尺度的特征,并且對不同尺度大小的特征進行融合,使得提取出來的特征的語義信息更加豐富;
3)引入了1x1的卷積核,1x1的卷積核可以用來方便地改變通道數,以便于不同尺度的特征圖經過通道數變換之后能夠concatenate在一起。
Inception網絡結構圖
1)inceptionv1的樸素版本
2)inceptionv1的加1x1卷積核變換通道數的版本
3)inceptionv2的不同類型的網絡結構
a)用兩個3x3代替5x5的卷積核
b) n x n卷積分解成若干個n x1、1 x 1、1 x n卷積的級聯
c) “展寬”結構的inception
inception網絡結構框架
inception代碼細節分析
from collections import namedtuple import warnings import torch from torch import nn, Tensor import torch.nn.functional as F # from .utils import load_state_dict_from_url from typing import Callable, Any, Optional, Tuple, List from torchsummary import summary __all__ = ['Inception3','inception_v3','InceptionOutputs','_InceptionOutputs'] # 預訓練inception模型的權重 model_urls = {'inception_v3_google':'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',} InceptionOutputs = namedtuple('InceptionOutputs',['logits','aux_logits']) InceptionOutputs.__annotations__ = {'logits','aux_logits'}_InceptionOutputs = InceptionOutputsdef inception_v3(pretrained: bool, progress:bool,**kwargs:Any):if pretrained:if 'transform_input'not in kwargs:kwargs['transform_input'] = Trueif 'aux_logits' in kwargs:original_aux_logits = kwargs['aux_logits ']kwrags['aux_logits '] = Trueelse:original_aux_logits = True# 使用預訓練模型,因此初始化參數init_weights設置為Falsekwargs['init_weights'] = Falsemodel = Inception3(**kwargs)state_dict = load_state_dict_from_url(model_urls['inception_v3_googlenet'],progress = progress)model.load_state_dict(state_dict)if not original_aux_logits:model.aux_logits = Falsemodel.AuxLogits = Nonereturn modelreturn Inception3(**kwargs)class Inception3(nn.Module):def __init__(self,num_classes:1000,aux_logits:True,transform_input:False,inception_blocks:None,init_weights:None):super(Inception3,self).__init__()# inception_blocks的不同類型if inception_blocks is None:inception_blocks = [BasicConv2d,InceptionA,InceptionB,InceptionC,InceptionD,InceptionE,InceptionAux]if init_weights is None:warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of ''torchvision. 
If you wish to keep the old behavior (which leads to long initialization times'' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)init_weights = Trueassert len(inception_blocks)==7# inception的不同部分conv_block = inception_blocks[0]inception_a = inception_blocks[1]inception_b = inception_blocks[2]inception_c = inception_blocks[3]inception_d = inception_blocks[4]inception_e = inception_blocks[5]inception_aux = inception_blocks[6]self.aux_logits = aux_logitsself.transform_input = transform_input# 不同inception結構有不一樣的卷積核大小self.Conv2d_1a_3x3 = conv_block(3,32,kernel_size = 3, stride = 2)self.Conv2d_2a_3x3 = conv_block(32,32,kernel_size = 3)self.Conv2d_2b_3x3 = conv_block(32,64,kernel_size = 3, padding = 1)self.maxpool1 = nn.MaxPool2d(kernel_size = 3, stride = 2)self.Conv2d_3b_1x1 = conv_block(64,80,kernel_size = 1)self.Conv2d_4a_3x3 = conv_block(80,192,kernel_size = 3)self.maxpool2 = nn.MaxPool2d(kernel_size = 3, stride = 2)self.Mixed_5b = inception_a(192,pool_features = 32)self.Mixed_5c = inception_a(256,pool_features = 64)self.Mixed_5d = inception_a(256,pool_features = 64)self.Mixed_6a = inception_b(288)self.Mixed_6b = inception_c(768,channels_7x7 = 128)self.Mixed_6c = inception_c(768,channels_7x7 = 160)self.Mixed_6d = inception_c(768,channels_7x7 = 160)self.Mixed_6e = inception_c(768,channels_7x7 = 192)self.Auxlogits = Noneself.Mixed_7a = inception_d(768)self.Mixed_7b = inception_e(1280)self.Mixed_7c = inception_2(2048)self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.dropout = nn.Dropout()# 分類器self.fc = nn.Linear(2048, num_classes)# 不同層的參數初始化方法if init_weights:if isinstance(m,nn.Conv2d) or isinstance(m,nn.Linear):import scipy.stats as statssteddev = m.stddev if hasattr(m,'stddev') else 0.1X = stats.truncnorm(-2,2,scale = stddev)values = torch.as_tensor(X.rvs(m.weights.numel()),dtype = m.weights.dtype)values = values.view(m.weights.size())with torch.no_grad():m.weight.copy_(values)elif 
isinstance(m,nn.BatchNorm2d):nn.init.constant_(m.weight,1)nn.init.constant_(m.bias,0)def _transform_input(self,x):# 對輸入圖片增加一維,并作中心化if self.transform_input:x_ch0 = torch.unsqueeze(x[:,0],1)*(0.229/0.5)+(0.485-0.5)/0.5x_ch1 = torch.unsqueeze(x[:,1],1)*(0.224/0.5)+(0.456-0.5)/0.5x_ch2 = torch.unsqueeze(x[:,2],1)*(0.225/0.5)+(0.406-0.5)/0.5return xdef _forward(self,x):# N x 3 x 299 x 299x = self.Conv2d_1a_3x3(x)# N x 32 x 149 x 149x = self.Conv2d_2a_3x3(x)# N x 32 x 147 x 147x = self.Conv2d_2b_3x3(x)# N x 64 x 147 x 147x = self.maxpool1(x)# N x 64 x 73 x 73x = self.Conv2d_3b_1x1(x)# N x 80 x 73 x 73x = self.Conv2d_4a_3x3(x)# N x 192 x 71 x 71x = self.maxpool2(x)# N x 192 x 35 x 35x = self.Mixed_5b(x)# N x 256 x 35 x 35x = self.Mixed_5c(x)# N x 288 x 35 x 35x = self.Mixed_5d(x)# N x 288 x 35 x 35x = self.Mixed_6a(x)# N x 768 x 17 x 17x = self.Mixed_6b(x)# N x 768 x 17 x 17x = self.Mixed_6c(x)# N x 768 x 17 x 17x = self.Mixed_6d(x)# N x 768 x 17 x 17x = self.Mixed_6e(x)# N x 768 x 17 x 17aux: Optional[Tensor] = Noneif self.AuxLogits is not None:if self.training:aux = self.AuxLogits(x)# N x 768 x 17 x 17x = self.Mixed_7a(x)# N x 1280 x 8 x 8x = self.Mixed_7b(x)# N x 2048 x 8 x 8x = self.Mixed_7c(x)# N x 2048 x 8 x 8# Adaptive average poolingx = self.avgpool(x)# N x 2048 x 1 x 1x = self.dropout(x)# N x 2048 x 1 x 1x = torch.flatten(x, 1)# N x 2048x = self.fc(x)# N x 1000 (num_classes)return x, aux# @torch.jit.unuseddef eager_outputs(self,x,aux):if self.training and self.aux_logits:return InceptionOutputs(x,aux)else:return xdef forward(self,x):x = self._transform_input(x)x,aux = self._forward(x)aux_defined = self.training and self.aux_logitsif torch.jit.is_scripting():if not aux_defined:warnings.warn("Scripted Inception3 always returns Inception3 Tuple")return InceptionOutputs(x, aux)else:return self.eager_outputs(x, aux) class InceptionA(nn.Module):def __init__(self,in_channels,pool_features,conv_block = None):super(InceptionA,self).__init__()if conv_block is 
None:conv_block = BasicConv2dself.branch1x1 = conv_block(in_channels,64,kernel_size = 1)self.branch5x5_1 = conv_block(in_channels,48,kernel_size = 1)self.branch5x5_2 = conv_block(48,64,kernel_size = 5,padding = 2)self.branch3x3dbl_1 = conv_block(in_channels,64,kernel_size = 1)self.branch3x3dbl_2 = conv_block(64,96,kernel_size = 1,padding = 1)self.branch3x3dbl_3 = conv_block(96,96,kernel_size = 1,padding = 1)self.branch_pool = conv_block(in_channels,pool_features,kernel_size = 1)def _forward(self,x):# 根據inceptionA的結構搭建網絡branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_1(x)branch5x5 = self.branch5x5_1(branch5x5)branch3x3dbl = self.branch3x3dbl_1(x)branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)branch_pool = F.avg_pool2d(x,kernel_size = 3, stride = 1, padding = 1)branch_pool = self.branch_pool(branch_pool)# 把不同尺度的輸出concatenate在一起,也可以寫成torch.cat((branch1x1,branch5x5,branch3x3dbl,branch_pool),axis = 1)outputs = [branch1x1,branch5x5,branch3x3dbl,branch_pool]return outputsdef forward(self,x):outputs = self._forward(x)return torch.cat(outputs,1) class InceptionB(nn.Module):def __init__(self,in_channels: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(InceptionB, self).__init__()if conv_block is None:conv_block = BasicConv2dself.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)def _forward(self, x: Tensor) -> List[Tensor]:# 根據inceptionB的結構搭建網絡branch3x3 = self.branch3x3(x)branch3x3dbl = self.branch3x3dbl_1(x)branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)outputs = [branch3x3, branch3x3dbl, branch_pool]return outputsdef forward(self, x: Tensor) -> Tensor:outputs = self._forward(x)return 
torch.cat(outputs, 1)class InceptionC(nn.Module):def __init__(self,in_channels: int,channels_7x7: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(InceptionC, self).__init__()if conv_block is None:conv_block = BasicConv2dself.branch1x1 = conv_block(in_channels, 192, kernel_size=1)c7 = channels_7x7self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))self.branch_pool = conv_block(in_channels, 192, kernel_size=1)def _forward(self, x: Tensor) -> List[Tensor]:# 根據inceptionC的結構搭建網絡branch1x1 = self.branch1x1(x)branch7x7 = self.branch7x7_1(x)branch7x7 = self.branch7x7_2(branch7x7)branch7x7 = self.branch7x7_3(branch7x7)branch7x7dbl = self.branch7x7dbl_1(x)branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)branch_pool = self.branch_pool(branch_pool)outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]return outputsdef forward(self, x: Tensor) -> Tensor:outputs = self._forward(x)return torch.cat(outputs, 1)class InceptionD(nn.Module):def __init__(self,in_channels: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(InceptionD, self).__init__()if conv_block is None:conv_block = BasicConv2dself.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)self.branch3x3_2 = conv_block(192, 320, kernel_size=3, 
stride=2)self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)def _forward(self, x: Tensor) -> List[Tensor]:# 根據inceptionD的結構搭建網絡branch3x3 = self.branch3x3_1(x)branch3x3 = self.branch3x3_2(branch3x3)branch7x7x3 = self.branch7x7x3_1(x)branch7x7x3 = self.branch7x7x3_2(branch7x7x3)branch7x7x3 = self.branch7x7x3_3(branch7x7x3)branch7x7x3 = self.branch7x7x3_4(branch7x7x3)branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)outputs = [branch3x3, branch7x7x3, branch_pool]return outputsdef forward(self, x: Tensor) -> Tensor:outputs = self._forward(x)return torch.cat(outputs, 1)class InceptionE(nn.Module):def __init__(self,in_channels: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(InceptionE, self).__init__()if conv_block is None:conv_block = BasicConv2dself.branch1x1 = conv_block(in_channels, 320, kernel_size=1)self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))self.branch_pool = conv_block(in_channels, 192, kernel_size=1)def _forward(self, x: Tensor) -> List[Tensor]:# 根據inceptionE的結構搭建網絡branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3_1(x)branch3x3 = [self.branch3x3_2a(branch3x3),self.branch3x3_2b(branch3x3),]branch3x3 = torch.cat(branch3x3, 1)branch3x3dbl = self.branch3x3dbl_1(x)branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)branch3x3dbl = 
[self.branch3x3dbl_3a(branch3x3dbl),self.branch3x3dbl_3b(branch3x3dbl),]branch3x3dbl = torch.cat(branch3x3dbl, 1)branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)branch_pool = self.branch_pool(branch_pool)outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]return outputsdef forward(self, x: Tensor) -> Tensor:outputs = self._forward(x)return torch.cat(outputs, 1)# inception的旁路輔助模塊 class InceptionAux(nn.Module):def __init__(self, in_channels,num_classes,conv_block = None):super(InceptionAux,self).__init__()if conv_block is None:conv_block = BasicConv2dself.conv0 = conv_block(in_channels,128,kernel_size = 1)self.conv1 = conv_block(128,768,kernel_size = 5)self.conv1.stddev = 0.01self.fc = nn.Linear(768, num_classes)self.fc.stddev = 0.001def forward(self,x):# N x 768 x 17 x 17x = F.avg_pool2d(x,kernel_size = 5, stride = 3)# N x 128 x 5 x 5x = self.conv0(x)# N x 768 x 1 x 1x = self.conv1(x)# Adaptive average poolingx = F.adaptive_avg_pool2d(x,(1,1))# N x 768 x 1 x 1x = torch.flatten(x,1)# N x 768x = self.fc(x)# N x1000return x class BasicConv2d(nn.Module):def __init__(self,in_channels,out_channels,**kwargs:Any):super(BasicConv2d,self).__init__()self.conv = nn.Conv2d(in_channels, out_channels,bias = False,**kwargs)self.bn = nn.BatchNorm2d(out_channels,eps = 0.001)def forward(self,x):x = self.conv(x)x = self.bn(x)return F.relu(x,inplace = True)總結
以上是生活随笔為你收集整理的Inception代码解读的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: NAR:脑疾病研究的“金牌助手”:Bra
- 下一篇: Nature methods | Ale