(pytorch-深度学习)实现稠密连接网络(DenseNet)
生活随笔
收集整理的這篇文章主要介紹了
(pytorch-深度学习)实现稠密连接网络(DenseNet)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
稠密連接網絡(DenseNet)
ResNet中的跨層連接設計引申出了數個后續工作。稠密連接網絡(DenseNet)與ResNet的主要區別在于在跨層連接上的主要區別:
- ResNet使用相加
- DenseNet使用連結
ResNet(左)與DenseNet(右):
圖中將部分前后相鄰的運算抽象為模塊AAA和模塊BBB。
- DenseNet里模塊BBB的輸出不是像ResNet那樣和模塊AAA的輸出相加,而是在通道維上連結。
- 這樣模塊AAA的輸出可以直接傳入模塊BBB后面的層。在這個設計里,模塊AAA相當于直接跟模塊BBB后面的所有層直接連接在了一起。這也是它被稱為“稠密連接”的原因。
DenseNet的主要構建模塊是稠密塊(dense block)和過渡層(transition layer)。
- 稠密塊定義了輸入和輸出是如何連結的
- 過渡層用來控制通道數,控制其大小
稠密塊
DenseNet使用了ResNet改良版的“批量歸一化、激活和卷積”結構:
import time import torch from torch import nn, optim import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk- 稠密塊由多個conv_block組成,每塊使用相同的輸出通道數。
- 在前向計算時,我們將每塊的輸入和輸出在通道維上連結。
定義一個有2個輸出通道數為10的卷積塊。
- 使用通道數為3的輸入時,我們會得到通道數為3+2×10=233+2\times 10=233+2×10=23的輸出。
- 卷積塊的通道數控制了輸出通道數相對于輸入通道數的增長,因此也被稱為增長率(growth rate)。
過渡層
- 每個稠密塊都會帶來通道數的增加,使用過多則會帶來過于復雜的模型。
- 過渡層用來控制模型復雜度。它通過1×11\times11×1卷積層來減小通道數,并使用步幅為2的平均池化層減半高和寬,從而進一步降低模型復雜度。
對上例中稠密塊的輸出,使用通道數為10的過渡層。此時輸出的通道數減為10,高和寬均減半。
blk = transition_block(23, 10) blk(Y).shape # torch.Size([4, 10, 4, 4])DenseNet模型
DenseNet首先使用和ResNet一樣的單卷積層和最大池化層。
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))- 接著使用4個稠密塊。
- 同ResNet一樣,我們可以設置每個稠密塊使用多少個卷積層(這里設成4)。
- 稠密塊里的卷積層通道數(即增長率)設為32,所以每個稠密塊將增加128個通道。
ResNet里通過步幅為2的殘差塊在每個模塊之間減小高和寬。DenseNet則使用過渡層來減半高和寬,并減半通道數。
num_channels, growth_rate = 64, 32 # num_channels為當前的通道數 num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一個稠密塊的輸出通道數num_channels = DB.out_channels# 在稠密塊之間加入通道數減半的過渡層if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2- 最后接上全局池化層和全連接層來輸出。
- 打印每個子模塊的輸出維度
- 獲取數據
訓練模型
def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start)) lr, num_epochs = 0.001, 5 optimizer = torch.optim.Adam(net.parameters(), lr=lr) train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)《動手學深度學習》
總結
以上是生活随笔為你收集整理的(pytorch-深度学习)实现稠密连接网络(DenseNet)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 这款堪称完美的PDF编辑器,帮你节省50
- 下一篇: 为何美洲蝉中意17这个质数?