U2Net——U-Net套U-Net——套娃式图像分割算法
U2Net
- 1 相關參考
- 2 U2?NetU^2-NetU2?Net 網絡結構
- 3 網絡代碼和測試
1 相關參考
論文名稱: U2-Net: Goging Deeper with Nested U-Structure for Salient Object Detetion
論文地址: https://arxiv.org/abs/2005.09007
官方源碼: https://github.com/xuebinqin/U-2-Net
參考代碼: Pytorch UNet
參考博客: https://blog.csdn.net/qq_37541097/article/details/126255483
參考視頻: bilibili 我為霹導舉大旗
建議大家可以先看霹導的原理講解視頻和代碼講解視頻,代碼寫的真的太優雅了,以下內容作為自己對重點的記錄和一些代碼中的修改!
2 U2?NetU^2-NetU2?Net 網絡結構
整體結構:
保留了原始的U-Net網絡結構,只是將每一個Block的內部結構做了很大的調整,換成了一個U-Net,同時針對整個結構的輸出做出調整,在訓練時,給六個輸出進行loss計算,在測試時只得到一個輸出。
Block結構RSU:
這里Block,除了輸入和輸出的通道會發生變化,在中間層進行卷積時,使用的通道數都是Mid_channels,同時在最下層的卷積中,使用的是膨脹卷積。 這里的L=7,指的是RSU-7,是En_1和Dn_1的內部結構,在前四層中,都是使用的是RSU結構;
在后面的兩層中,使用的是RSU-4F,其中的卷積層使用的是膨脹卷積,避免因為深度太深,導致圖像尺寸太小,丟失特征,RSU-4F結構如下:
RSU-4F:
這里向下使用了兩層的膨脹卷積,進行特征恢復,避免因為網絡深度太深,導致特征丟失的問題!
損失函數:
網絡在訓練的時候,是對六個輸出分別和GT進行BCE(二值交叉熵)計算,然后對損失求和進行反向傳播,公式如下:
L=∑m=1Mwside?(m)lside?(m)+wfuse?lfuse?L=\sum_{m=1}^{M} w_{\text {side }}^{(m)} l_{\text {side }}^{(m)}+w_{\text {fuse }} l_{\text {fuse }} L=m=1∑M?wside?(m)?lside?(m)?+wfuse??lfuse??
在本網絡中,前面一部分是六個輸出和GT的損失,第二部分是最后的融合圖像和GT的損失,代碼如下:
import torch import torch.nn as nn from torch.nn import functional as F class U2criterion(nn.Module):def __init__(self):super(U2criterion, self).__init__()def forward(self, inputs, target):losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]total_loss = sum(losses)return total_loss3 網絡代碼和測試
from typing import Union, List import torch import torch.nn as nn import torch.nn.functional as Fclass ConvBNReLU(nn.Module):def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1):super().__init__()padding = kernel_size // 2 if dilation == 1 else dilation # 保持圖像大小不變self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False), # 因為后面有BN,bias不起作用nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True) )def forward(self, x):return self.conv(x)class DownConvBNReLu(ConvBNReLU):def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, flag=True):super().__init__(in_ch, out_ch, kernel_size, dilation)self.down_flag = flagdef forward(self, x):if self.down_flag:x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)return self.conv(x)class UpConvBNReLU(ConvBNReLU):def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, flag=True):super().__init__(in_ch, out_ch, kernel_size, dilation)self.up_flag = flagdef forward(self, x1, x2): # x1為下面傳入的, x2為左邊傳入的if self.up_flag:x1 = F.interpolate(x1, size=x2.shape[2:], mode="bilinear", align_corners=False)x = torch.cat([x1, x2], dim=1)return self.conv(x)class RSU(nn.Module):def __init__(self, height, in_ch, mid_ch, out_ch):super().__init__()assert height >= 2self.conv_in = ConvBNReLU(in_ch, out_ch) # 這個是不算在height上的encode_list = [DownConvBNReLu(out_ch, mid_ch, flag=False)]decode_list = [UpConvBNReLU(mid_ch*2, mid_ch, flag=False)]for i in range(height-2): # 含有上下采樣的模塊encode_list.append(DownConvBNReLu(mid_ch, mid_ch))decode_list.append(UpConvBNReLU(mid_ch*2, mid_ch if i < height-3 else out_ch)) # 這里最后的decode的輸出是out_chencode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))self.encode_modules = nn.ModuleList(encode_list)self.decode_modules = nn.ModuleList(decode_list)def forward(self, x):x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop() # 這是移除list最后的一個數據,并且將該數據賦值給x,這里的x是含有空洞卷積的輸出for m in self.decode_modules:x2 = encode_outputs.pop() # 這里是倒數第二深的輸出,x表示下面的,x2表示左邊的x = m(x, x2) # 將下面的,和左邊的一起傳入到上卷積中return x + x_in # 這里是最上面一層進行相加class RSU4F(nn.Module):def __init__(self, in_ch, mid_ch, out_ch):super().__init__()self.conv_in = ConvBNReLU(in_ch, out_ch)self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),ConvBNReLU(mid_ch, mid_ch, dilation=2),ConvBNReLU(mid_ch, mid_ch, dilation=4),ConvBNReLU(mid_ch, mid_ch, dilation=8)])self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch*2, mid_ch, dilation=4),ConvBNReLU(mid_ch*2, mid_ch, dilation=2),ConvBNReLU(mid_ch*2, out_ch)])def forward(self, x):x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop()for m in self.decode_modules:x2 = encode_outputs.pop()x = m(torch.cat([x, x2], dim=1))return x+x_inclass U2Net(nn.Module):def __init__(self, cfg, out_ch=1):super().__init__()assert "encode" in cfgassert "decode" in cfgself.encode_num = len(cfg["encode"])encode_list = []side_list = []for c in cfg["encode"]:# [height, in_ch, mid_ch, out_ch, RSU4F, side]assert len(c) == 6encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4])) # 這里的*是將列表解開為單獨的數值,這樣才能傳入到函數中if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))self.encode_modules = nn.ModuleList(encode_list)decode_list = []for c in cfg["decode"]:assert len(c) == 6decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))self.decode_modules = nn.ModuleList(decode_list)self.side_modules = nn.ModuleList(side_list)self.out_conv = nn.Conv2d(self.encode_num*out_ch, out_ch, kernel_size=1) # 這里是針對cat后的結果進行卷積,得到最后的out_ch=1def forward(self, x):_, _, h, w = x.shapeencode_outputs = []for i, m in enumerate(self.encode_modules):x = m(x)encode_outputs.append(x)if i != self.encode_num - 1: # 除了最后一個encode_block不用下采樣,其余每一個block都需要下采樣x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)x = encode_outputs.pop()decode_outputs = [x]for m in self.decode_modules:x2 = encode_outputs.pop()x = F.interpolate(x, size=x2.shape[2:], mode="bilinear", align_corners=False)x = m(torch.cat([x, x2], dim=1))decode_outputs.insert(0, x) #這里是保證了從上到下的decode層的輸出,在列表中的遍歷是從0到5side_outputs = []for m in self.side_modules:x = decode_outputs.pop()x = F.interpolate(m(x), size=[h,w], mode="bilinear", align_corners=False)side_outputs.insert(0, x)x = self.out_conv(torch.cat(side_outputs, dim=1))if self.training: # 在訓練的時候,需要將6個輸出都拿出來進行loss計算,return [x] + side_outputselse: # 非訓練時,直接sigmoid后的數據return torch.sigmoid(x)# return torch.sigmoid(x)def u2net_full(in_ch=3, out_ch=1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side"encode": [[7, in_ch, 32, 64, False, False], # En1[6, 64, 32, 128, False, False], # En2[5, 128, 64, 256, False, False], # En3[4, 256, 128, 512, False, False], # En4[4, 512, 256, 512, True, False], # En5[4, 512, 256, 512, True, True]], # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 1024, 256, 512, True, True], # De5[4, 1024, 128, 256, False, True], # De4[5, 512, 64, 128, False, True], # De3[6, 256, 32, 64, False, True], # De2[7, 128, 16, 64, False, True]] # De1}return U2Net(cfg, out_ch)def u2net_lite(in_ch=3, out_ch=1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side"encode": [[7, in_ch, 16, 64, False, False], # En1[6, 64, 16, 64, False, False], # En2[5, 64, 16, 64, False, False], # En3[4, 64, 16, 64, False, False], # En4[4, 64, 16, 64, True, False], # En5[4, 64, 16, 64, True, True]], # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 128, 16, 64, True, True], # De5[4, 128, 16, 64, False, True], # De4[5, 128, 16, 64, False, True], # De3[6, 128, 16, 64, False, True], # De2[7, 128, 16, 64, False, True]] # De1}return U2Net(cfg, out_ch)# net = u2net_full(1,1) # x = torch.randn(16,1,256,256) # net.eval() # print(net(x))貼一個網絡參數計算代碼:
def count_parameters(model): # 傳入的是模型實例對象params = [p.numel() for p in model.parameters() if p.requires_grad] # for item in params: # print(f'{item:>16}') # 參數大于16的展示print(f'________\n{sum(params):>16}') # 大于16的進行統計,可以自行修改網絡測試:
再說一下,霹導寫的代碼真的很優雅,可以去看霹導的代碼講解和網絡結構講解!!
總結
以上是生活随笔為你收集整理的U2Net——U-Net套U-Net——套娃式图像分割算法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 当下推荐系统的分析和关于长尾效应的解决猜
- 下一篇: mysql 查连接数,查看MySQL的连