shufflenetv2代码解读
生活随笔
收集整理的這篇文章主要介紹了
shufflenetv2代码解读
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
shufflenetv2代碼解讀
目錄
- shufflenetv2代碼解讀
- 概述
- shufflenetv2網絡結構圖
- shufflenetv2架構參數
- shufflenetv2代碼細節分析
概述
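shufflenetv2是發表在 ECCV 2018 上的一篇關于模型壓縮和模型加速的文章,其中用到的主要技巧有兩點:深度可分離卷積、通道混洗(channel shuffle)。其中,深度可分離卷積是為了減少參數量和計算量、提高運算速度;通道混洗是為了讓不同分組通道的特征之間可以產生信息交互,從而獲取更加豐富的語義信息。此外,在 stride 為 1 的基本單元中,shufflenetv2 還使用了通道分割(channel split):輸入沿通道維度一分為二,一半直接保留,另一半經過卷積分支,兩者拼接後再做通道混洗。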
這個系列的文章把主要精力放在代碼的分析上,如果想要進一步了解shufflenetv2原理的同學,可以參考原論文《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》(原文此處的參考鏈接在轉載時已丟失)。
shufflenetv2網絡結構圖
shufflenetv2架構參數
shufflenetv2代碼細節分析
import torch
import torch.nn as nn
from torch import Tensor
# The original relative import (`from .utils import ...`) only resolves inside
# the torchvision package; torch.hub provides load_state_dict_from_url directly.
from torch.hub import load_state_dict_from_url
from typing import Callable, Any, List

# Public model constructors exported by this module.
__all__ = [
    'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]

# Download URLs of the pretrained ShuffleNetV2 weights.
# None means no pretrained weights are available for that width variant.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` groups.

    The tensor is reshaped to (B, groups, C // groups, H, W), the two channel
    axes are swapped, and the result is flattened back to (B, C, H, W). This
    mixes channels that came from different groups (e.g. the two branches of
    an InvertedResidual unit) so information can flow between them.

    Args:
        x: 4-D tensor in (batch, channels, height, width) layout.
        groups: number of channel groups; InvertedResidual uses 2.

    Returns:
        A tensor with the same shape as ``x`` and permuted channels.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # (B, C, H, W) -> (B, G, C/G, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # Swap the group axis with the per-group channel axis.
    x = torch.transpose(x, 1, 2).contiguous()
    # Flatten back to the input layout (B, C, H, W).
    x = x.view(batchsize, -1, height, width)
    return x


class InvertedResidual(nn.Module):
    """ShuffleNetV2 building unit.

    * ``stride == 1`` (figure (c) of the paper): the input is split in half
      along the channel axis; one half is passed through unchanged, the other
      goes through ``branch2``; the halves are concatenated and shuffled.
    * ``stride > 1`` (figure (d), downsampling): the full input goes through
      both ``branch1`` and ``branch2``, each of which contains a strided 3x3
      depthwise conv; the outputs are concatenated and shuffled.

    Args:
        inp: number of input channels.
        oup: number of output channels; each branch produces ``oup // 2``.
        stride: stride of the 3x3 depthwise convs, must be in [1, 3].

    Raises:
        ValueError: if ``stride`` is outside [1, 3].
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # With stride 1 the untouched half plus branch2's output must give back
        # `inp` channels, i.e. inp == 2 * branch_features (`<< 1` doubles).
        assert (self.stride != 1) or (inp == branch_features << 1)

        # branch1 / branch2 are the left / right branches of figure (d).
        if self.stride > 1:
            # Left branch: 3x3 depthwise conv (strided) -> BN -> 1x1 conv -> BN -> ReLU.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Stride-1 units pass half of the input through unchanged instead.
            self.branch1 = nn.Sequential()

        # Right branch: 1x1 conv -> BN -> ReLU -> 3x3 depthwise conv (strided)
        # -> BN -> 1x1 conv -> BN -> ReLU. With stride 1 it only receives half
        # of the input channels, hence the conditional input width.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int,
        o: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> nn.Conv2d:
        """Build a depthwise conv (``groups=i``: one filter group per input channel)."""
        # `bias` must be passed by keyword: nn.Conv2d's 6th positional
        # parameter is `dilation`, not `bias`.
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            # Figure (c): split the channels in two; the first half is
            # forwarded unchanged, the second half goes through branch2.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            # Figure (d): both branches see the full input and downsample it;
            # concatenating them doubles the channel count per branch.
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)
        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 image classifier.

    Layout: conv1 (3x3, stride 2) -> max-pool (stride 2) -> stage2 -> stage3
    -> stage4 -> conv5 (1x1) -> global average pool -> fully-connected layer.

    Args:
        stages_repeats: number of units in stage2, stage3 and stage4 (3 ints).
            The first unit of every stage downsamples with stride 2.
        stages_out_channels: output channels of conv1, stage2, stage3, stage4
            and conv5 (5 ints).
        num_classes: output size of the final fully-connected layer.
        inverted_residual: factory called as
            ``inverted_residual(in_channels, out_channels, stride)`` to build
            each unit.

    Raises:
        ValueError: if the two lists do not have 3 and 5 entries respectively.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        # Stem: 3-channel input -> conv1 -> max-pool.
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            # One downsampling unit (stride 2) followed by `repeats - 1`
            # channel-split units (stride 1).
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # Run the ShuffleNetV2 pipeline described in the class docstring.
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global average pool over H and W
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
    """Build a ShuffleNetV2 and optionally load pretrained weights.

    Args:
        arch: key into ``model_urls`` identifying the width variant.
        pretrained: if True, download the checkpoint listed in ``model_urls``
            and load it into the model.
        progress: forwarded to ``load_state_dict_from_url`` (download
            progress bar).
        *args, **kwargs: forwarded to the ``ShuffleNetV2`` constructor.

    Raises:
        NotImplementedError: if ``pretrained`` is True but ``model_urls``
            has no weights for ``arch``.
    """
    model = ShuffleNetV2(*args, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        # Download the checkpoint and load it into the freshly built model.
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

    return model


# The width variants share the same stage depths ([4, 8, 4]) and differ only in
# their per-stage output channel counts.
def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """ShuffleNetV2 with 0.5x output channels.

    Args:
        pretrained: load the weights from ``model_urls['shufflenetv2_x0.5']``.
        progress: show a progress bar while downloading the weights.
        **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
    """
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """ShuffleNetV2 with 1.0x output channels.

    Args:
        pretrained: load the weights from ``model_urls['shufflenetv2_x1.0']``.
        progress: show a progress bar while downloading the weights.
        **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
    """
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """ShuffleNetV2 with 1.5x output channels.

    Args:
        pretrained: no weights are published for this variant, so True
            raises NotImplementedError.
        progress: show a progress bar while downloading the weights.
        **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
    """
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """ShuffleNetV2 with 2.0x output channels.

    Args:
        pretrained: no weights are published for this variant, so True
            raises NotImplementedError.
        progress: show a progress bar while downloading the weights.
        **kwargs: forwarded to ``ShuffleNetV2`` (e.g. ``num_classes``).
    """
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
以上是生活随笔為你收集整理的shufflenetv2代码解读的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 生物大数据时代,如何做好数据管理和再利用
- 下一篇: Briefings in Bioinfo