yolov4源码_YOLOv4特征提取网络——CSPDarkNet结构解析及PyTorch实现
1 YOLOv4目標檢測模型
自從Redmon說他不在更新YOLO系列之后,我一度以為這么好用的框架就要慢慢淡入歷史了,事實是我多慮了。YOLOv4在使用YOLO Loss的基礎上,使用了新的backbone,并且集成了很多新的優化方法及模型策略,如Mosaic,PANet,CmBN,SAT訓練,CIoU loss,Mish激活函數,label smoothing等等??芍^集SoAT之大成,也實現了很好的檢測精度和速度。 這篇博客主要討論YOLOv4中的backbone——CSP-DarkNet,以及其實現的所必需的Mish激活函數,CSP結構和DarkNet。
開源項目YOLOv5相比YOLOv4有了比較夸張的突破,成為了全方位吊打EfficientDet的存在,其特征提取網絡也是CSP-DarkNet。
1.1 Mish激活函數
激活函數是為了提高網絡的學習能力,提升梯度的傳遞效率。CNN常用的激活函數也在不斷地發展,早期網絡常用的有ReLU,LeakyReLU,softplus等,后來又有了Swish,Mish等。Mish激活函數的計算復雜度比ReLU要高不少,如果你的計算資源不是很夠,可以考慮使用LeakyReLU代替Mish。在介紹之前,需要先了解softplus和tanh函數。
softplus激活函數的公式如下:
上圖是其輸出曲線,softplus和ReLU的曲線具有相似性,但是其比ReLU更為平滑。
目前的普遍看法是,平滑的激活函數允許更好的信息深入神經網絡,從而得到更好的準確性和泛化。
tanh的公式如下:
Mish激活函數的公式為:
上圖為Mish的曲線。首先其和ReLU一樣,都是無正向邊界的,可以避免梯度飽和;其次Mish函數是處處光滑的,并且在絕對值較小的負值區域允許一些負值。
1.2 CSP結構和DarkNet
CSP和DarkNet的結構我在之前的博客中有介紹,如果不清楚的同學,歡迎戳鏈接:CSPNet,DarkNet。
這里為了方便對比,給出DarkNet-53的架構圖:
1.3 CSP-DarkNet
博客【darknet】darknet——CSPDarknet53網絡結構圖(YOLO V4使用)畫出了DarkNet-53的結構圖,畫得很簡明清晰,我借過來用一下:
CSP-DarkNet和CSP-ResNe(X)t的整體思路是差不多的,沿用網絡的濾波器尺寸和整體結構,在每組Residual block加上一個Cross Stage Partial結構。并且,CSP-DarkNet中也取消了Bottleneck的結構,減少了參數使其更容易訓練。
但是,有個地方看圖還是不清楚——CSP輸入的時候通道是什么比例劃分的? 查看了一些源碼,最終確認了結構,在一下部分進行討論。
【討論】
按照CSP論文中的思路,我開始認為的CSP結構應該是這樣的——特征輸入之后,通過一個比例將其分為兩個部分(CSPNet中是二等份),然后再分別輸入block結構,以及后面的Partial transition處理。這樣符合CSPNet論文中的理論思路。
但是實際上,我參考了一些源碼以及darknet配置文件中的網絡參數,得到的結構是這樣的:
和我所理解不同的是,實際的結構在輸入后沒有按照通道劃分成兩個部分,而是直接用兩路的1x1卷積將輸入特征進行變換。 可以理解的是,將全部的輸入特征利用兩路1x1進行transition,比直接劃分通道能夠進一步提高特征的重用性,并且在輸入到resiudal block之前也確實通道減半,減少了計算量。雖然不知道這是否吻合CSP最初始的思想,但是其效果肯定是比我設想的那種情況更好的。性能是王道,我們也按照實際的結構來復現。
2 PyTorch實現CSPDarkNet
這個復現包括了全局池化和全連接層,YOLOv4中使用CSP-DarkNet只使用之前的卷積層用作特征提取。
2.1 Mish激活函數和BN_CONV_Mish結構
class Mish(nn.Module):def __init__(self):super(Mish, self).__init__()def forward(self, x):return x * torch.tanh(F.softplus(x))class BN_Conv_Mish(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):super(BN_Conv_Mish, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation,groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_channels)def forward(self, x):out = self.bn(self.conv(x))return Mish()(out)2.2 Basic block
使用的是殘差結構,要注意的是:按照residual block的一貫思路,shortcut之前的最后一層卷積使用線性激活(不適使用激活函數)。
class ResidualBlock(nn.Module):"""basic residual block for CSP-Darknet"""def __init__(self, chnls, inner_chnnls=None):super(ResidualBlock, self).__init__()if inner_chnnls is None:inner_chnnls = chnlsself.conv1 = BN_Conv_Mish(chnls, inner_chnnls, 1, 1, 0) # always use samepaddingself.conv2 = nn.Conv2d(inner_chnnls, chnls, 3, 1, 1, bias=False)self.bn = nn.BatchNorm2d(chnls)def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = self.bn(out) + xreturn Mish()(out)2.3 CSP-DarkNet
按照上圖的結構實現CSP結構并搭建網絡。需要注意的是,第一個CSP結構和后面的有略微差別:
class CSPFirst(nn.Module):"""First CSP Stage"""def __init__(self, in_chnnls, out_chnls):super(CSPFirst, self).__init__()self.dsample = BN_Conv_Mish(in_chnnls, out_chnls, 3, 2, 1) # same paddingself.trans_0 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)self.trans_1 = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)self.block = ResidualBlock(out_chnls, out_chnls//2)self.trans_cat = BN_Conv_Mish(2*out_chnls, out_chnls, 1, 1, 0)def forward(self, x):x = self.dsample(x)out_0 = self.trans_0(x)out_1 = self.trans_1(x)out_1 = self.block(out_1)out = torch.cat((out_0, out_1), 1)out = self.trans_cat(out)return outclass CSPStem(nn.Module):"""CSP structures including downsampling"""def __init__(self, in_chnls, out_chnls, num_block):super(CSPStem, self).__init__()self.dsample = BN_Conv_Mish(in_chnls, out_chnls, 3, 2, 1)self.trans_0 = BN_Conv_Mish(out_chnls, out_chnls//2, 1, 1, 0)self.trans_1 = BN_Conv_Mish(out_chnls, out_chnls//2, 1, 1, 0)self.blocks = nn.Sequential(*[ResidualBlock(out_chnls//2) for _ in range(num_block)])self.trans_cat = BN_Conv_Mish(out_chnls, out_chnls, 1, 1, 0)def forward(self, x):x = self.dsample(x)out_0 = self.trans_0(x)out_1 = self.trans_1(x)out_1 = self.blocks(out_1)out = torch.cat((out_0, out_1), 1)out = self.trans_cat(out)return outclass CSP_DarkNet(nn.Module):"""CSP-DarkNet"""def __init__(self, num_blocks: object, num_classes=1000) -> object:super(CSP_DarkNet, self).__init__()chnls = [64, 128, 256, 512, 1024]self.conv0 = BN_Conv_Mish(3, 32, 3, 1, 1) # same paddingself.neck = CSPFirst(32, chnls[0])self.body = nn.Sequential(*[CSPStem(chnls[i], chnls[i+1], num_blocks[i]) for i in range(4)])self.global_pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(chnls[4], num_classes)def forward(self, x):out = self.conv0(x)out = self.neck(out)out = self.body(out)out = self.global_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return F.softmax(out)def csp_darknet_53(num_classes=1000):return CSP_DarkNet([2, 8, 8, 4], num_classes)2.4 測試網絡結構
net = csp_darknet_53()
summary(net, (3, 256, 256))
總結
以上是生活随笔為你收集整理的yolov4源码_YOLOv4特征提取网络——CSPDarkNet结构解析及PyTorch实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 贴字开头的成语有哪些啊?
- 下一篇: 千里啼绿映红的下一句是什么啊?