pytorch中的pre-train函数模型引用及修改(增减网络层,修改某层参数等)
版權(quán)聲明:本文為博主原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請附上原文出處鏈接和本聲明。
本文鏈接:https://blog.csdn.net/whut_ldz/article/details/78845947
一、pytorch中的pre-train模型
卷積神經(jīng)網(wǎng)絡(luò)的訓(xùn)練是耗時的,很多場合不可能每次都從隨機(jī)初始化參數(shù)開始訓(xùn)練網(wǎng)絡(luò)。
pytorch中自帶幾種常用的深度學(xué)習(xí)網(wǎng)絡(luò)預(yù)訓(xùn)練模型,如VGG、ResNet等。往往為了加快學(xué)習(xí)的進(jìn)度,在訓(xùn)練的初期我們直接加載pre-train模型中預(yù)先訓(xùn)練好的參數(shù),model的加載如下所示:
import torchvision.models as models
?
#resnet
model = models.ResNet(pretrained=True)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
?
#vgg
model = models.VGG(pretrained=True)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
model = models.vgg16_bn(pretrained=True)
二、預(yù)訓(xùn)練模型的修改
1.參數(shù)修改
對于簡單的參數(shù)修改,這里以resnet預(yù)訓(xùn)練模型舉例,resnet源代碼在Github點擊打開鏈接。
resnet網(wǎng)絡(luò)最后一層分類層fc是對1000種類型進(jìn)行劃分,對于自己的數(shù)據(jù)集,如果只有9類,修改的代碼如下:
# coding=UTF-8
import torchvision.models as models
?
#調(diào)用模型
model = models.resnet50(pretrained=True)
#提取fc層中固定的參數(shù)
fc_features = model.fc.in_features
#修改類別為9
model.fc = nn.Linear(fc_features, 9)
2.增減卷積層
前一種方法只適用于簡單的參數(shù)修改,有的時候我們往往要修改網(wǎng)絡(luò)中的層次結(jié)構(gòu),這時只能用參數(shù)覆蓋的方法,即自己先定義一個類似的網(wǎng)絡(luò),再將預(yù)訓(xùn)練中的參數(shù)提取到自己的網(wǎng)絡(luò)中來。這里以resnet預(yù)訓(xùn)練模型舉例。
# coding=UTF-8
import torchvision.models as models
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
?
class CNN(nn.Module):
?
? ? def __init__(self, block, layers, num_classes=9):
? ? ? ? self.inplanes = 64
? ? ? ? super(ResNet, self).__init__()
? ? ? ? self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?bias=False)
? ? ? ? self.bn1 = nn.BatchNorm2d(64)
? ? ? ? self.relu = nn.ReLU(inplace=True)
? ? ? ? self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
? ? ? ? self.layer1 = self._make_layer(block, 64, layers[0])
? ? ? ? self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
? ? ? ? self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
? ? ? ? self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
? ? ? ? self.avgpool = nn.AvgPool2d(7, stride=1)
? ? ? ? #新增一個反卷積層
? ? ? ? self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0, groups=1, bias=False, dilation=1)
? ? ? ? #新增一個最大池化層
? ? ? ? self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
? ? ? ? #去掉原來的fc層,新增一個fclass層
? ? ? ? self.fclass = nn.Linear(2048, num_classes)
?
? ? ? ? for m in self.modules():
? ? ? ? ? ? if isinstance(m, nn.Conv2d):
? ? ? ? ? ? ? ? n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
? ? ? ? ? ? ? ? m.weight.data.normal_(0, math.sqrt(2. / n))
? ? ? ? ? ? elif isinstance(m, nn.BatchNorm2d):
? ? ? ? ? ? ? ? m.weight.data.fill_(1)
? ? ? ? ? ? ? ? m.bias.data.zero_()
?
? ? def _make_layer(self, block, planes, blocks, stride=1):
? ? ? ? downsample = None
? ? ? ? if stride != 1 or self.inplanes != planes * block.expansion:
? ? ? ? ? ? downsample = nn.Sequential(
? ? ? ? ? ? ? ? nn.Conv2d(self.inplanes, planes * block.expansion,
? ? ? ? ? ? ? ? ? ? ? ? ? kernel_size=1, stride=stride, bias=False),
? ? ? ? ? ? ? ? nn.BatchNorm2d(planes * block.expansion),
? ? ? ? ? ? )
?
? ? ? ? layers = []
? ? ? ? layers.append(block(self.inplanes, planes, stride, downsample))
? ? ? ? self.inplanes = planes * block.expansion
? ? ? ? for i in range(1, blocks):
? ? ? ? ? ? layers.append(block(self.inplanes, planes))
?
? ? ? ? return nn.Sequential(*layers)
?
? ? def forward(self, x):
? ? ? ? x = self.conv1(x)
? ? ? ? x = self.bn1(x)
? ? ? ? x = self.relu(x)
? ? ? ? x = self.maxpool(x)
?
? ? ? ? x = self.layer1(x)
? ? ? ? x = self.layer2(x)
? ? ? ? x = self.layer3(x)
? ? ? ? x = self.layer4(x)
?
? ? ? ? x = self.avgpool(x)
? ? ? ? #新加層的forward
? ? ? ? x = x.view(x.size(0), -1)
? ? ? ? x = self.convtranspose1(x)
? ? ? ? x = self.maxpool2(x)
? ? ? ? x = x.view(x.size(0), -1)
? ? ? ? x = self.fclass(x)
?
? ? ? ? return x
?
#加載model
resnet50 = models.resnet50(pretrained=True)
cnn = CNN(Bottleneck, [3, 4, 6, 3])
#讀取參數(shù)
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# 將pretrained_dict里不屬于model_dict的鍵剔除掉
pretrained_dict = ?{k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現(xiàn)有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
cnn.load_state_dict(model_dict)
# print(resnet50)
print(cnn)
以上就是相關(guān)的內(nèi)容,本人剛?cè)腴T的小白一枚,請輕噴~
————————————————
版權(quán)聲明:本文為CSDN博主「whut_ldz」的原創(chuàng)文章,遵循 CC 4.0 BY-SA 版權(quán)協(xié)議,轉(zhuǎn)載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/whut_ldz/article/details/78845947
總結(jié)
以上是生活随笔為你收集整理的pytorch中的pre-train函数模型引用及修改(增减网络层,修改某层参数等)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 笔记:基于DCNN的图像语义分割综述
- 下一篇: pytorch常用函数API简析与汇总—