1. Main Idea
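PSPNet proposes the pyramid pooling module (PPM), which aggregates context from regions of different sizes and thereby improves the network's ability to capture global information.
In existing deep networks, the receptive field of an operation directly determines how much context that operation can see, so enlarging the receptive field lets the network incorporate more context.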
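To make the receptive-field argument concrete, here is a minimal sketch that computes the *theoretical* receptive field of a stack of layers with the standard recurrence. The function name `receptive_field` and the (kernel, stride, dilation) tuple layout are illustrative assumptions, not part of PSPNet. The effective receptive field measured in practice is usually smaller than this number, which is part of the paper's motivation for adding explicit global pooling.

```python
def receptive_field(layers):
    """Theoretical receptive field of a stack of conv/pool layers.

    `layers` is a list of (kernel_size, stride, dilation) tuples ordered
    from input to output.
    """
    rf = 1    # a single output unit initially sees one input pixel
    jump = 1  # distance in input pixels between adjacent units at this depth
    for kernel, stride, dilation in layers:
        effective_kernel = dilation * (kernel - 1) + 1
        rf += (effective_kernel - 1) * jump
        jump *= stride
    return rf


# Three stacked 3x3 convs with stride 1 see a 7x7 input window.
print(receptive_field([(3, 1, 1)] * 3))  # 7
# A dilated 3x3 conv (dilation 2) covers a 5x5 window with the same 9 weights.
print(receptive_field([(3, 1, 2)]))  # 5
```

Each extra layer grows the window only by (effective kernel - 1) times the accumulated stride, whereas a global pooling branch covers the entire feature map in a single step.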
2. Method
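Step 1: Use average pooling to obtain features at different scales. The PPM fuses features from four pyramid scales:
- The red level in the paper's PPM diagram is the coarsest scale, implemented as a single global average pooling over the whole feature map.
- The other levels split the feature map into a different number of sub-regions (bins) and average-pool within each bin. The four levels in the paper use 1x1, 2x2, 3x3, and 6x6 bins.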
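The bin layout of Step 1 can be illustrated with a small pure-Python sketch of adaptive average pooling on a single channel stored as a 2D list. It assumes the bin-boundary rule used by PyTorch's `nn.AdaptiveAvgPool2d` (start = floor(i * in / out), end = ceil((i + 1) * in / out)); the helper names `adaptive_bins` and `adaptive_avg_pool2d` are hypothetical.

```python
def adaptive_bins(in_size, out_size):
    """Return one (start, end) index range per output bin.

    Assumed to mirror nn.AdaptiveAvgPool2d: start = floor(i * in / out),
    end = ceil((i + 1) * in / out). Integer arithmetic avoids rounding issues.
    """
    return [
        ((i * in_size) // out_size, -(-((i + 1) * in_size) // out_size))
        for i in range(out_size)
    ]


def adaptive_avg_pool2d(grid, out_size):
    """Average-pool a 2D list into an out_size x out_size grid of bin means."""
    row_bins = adaptive_bins(len(grid), out_size)
    col_bins = adaptive_bins(len(grid[0]), out_size)
    pooled = []
    for r0, r1 in row_bins:
        pooled_row = []
        for c0, c1 in col_bins:
            cells = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            pooled_row.append(sum(cells) / len(cells))
        pooled.append(pooled_row)
    return pooled


# A 6x6 single-channel feature map pooled at the paper's four scales.
fmap = [[r * 6 + c for c in range(6)] for r in range(6)]
for scale in (1, 2, 3, 6):
    print(scale, len(adaptive_avg_pool2d(fmap, scale)))
print(adaptive_avg_pool2d(fmap, 1))  # [[17.5]]
print(adaptive_avg_pool2d(fmap, 2))  # [[7.0, 10.0], [25.0, 28.0]]
print(adaptive_bins(5, 3))  # [(0, 2), (1, 4), (3, 5)]
```

The 1x1 scale reduces to global average pooling. When the input size is not divisible by the number of bins (5 positions into 3 bins above), neighbouring bins overlap.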
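Step 2: After pooling, each level is followed by a 1x1 convolution that reduces the channel dimension (in the paper, to 1/N of the original, where N is the number of pyramid levels).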
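Step 2 amounts to a per-pixel linear map over channels. The sketch below assumes a [C][H][W] nested-list layout and uses hypothetical helper names. It applies a 1x1 convolution without bias, normalization, or activation (the mmseg `ConvModule` used in this article's code also adds norm and activation layers) and computes the paper's 1/N channel reduction. The 2048-channel input is an assumption matching the final stage of ResNet-50/101.

```python
def conv1x1(feature, weight):
    """1x1 convolution (no bias) on a feature laid out as [C_in][H][W].

    `weight` is a C_out x C_in matrix. The same matrix mixes the channels at
    every spatial position, so only the channel count changes.
    """
    c_in = len(feature)
    height, width = len(feature[0]), len(feature[0][0])
    return [
        [
            [sum(w[c] * feature[c][y][x] for c in range(c_in)) for x in range(width)]
            for y in range(height)
        ]
        for w in weight
    ]


def reduced_channels(in_channels, num_levels):
    """Channels per pyramid level under the paper's 1/N reduction rule."""
    return in_channels // num_levels


# Two input channels on a 2x2 grid, reduced to one channel by averaging them.
feat = [[[1, 2], [3, 4]], [[10, 20], [30, 40]]]
print(conv1x1(feat, [[0.5, 0.5]]))  # [[[5.5, 11.0], [16.5, 22.0]]]

# Assumed 2048-channel backbone output and N = 4 pyramid levels.
per_level = reduced_channels(2048, 4)
print(per_level, 2048 + 4 * per_level)  # 512 4096
```

With four levels each branch keeps 512 channels, so the concatenated feature (input plus four branches) has 2048 + 4 * 512 = 4096 channels, twice the input width.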
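Step 3: Bilinearly upsample every level back to the spatial size of the feature map that entered the PPM head (not the size of the original image), concatenate the results with that feature map, and pass the concatenation through a convolution layer to produce the final per-pixel prediction.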
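Step 3 can be sketched in pure Python as bilinear upsampling followed by channel concatenation. The resize assumes the half-pixel mapping that `F.interpolate(mode='bilinear', align_corners=False)` applies (source = (dst + 0.5) * in / out - 0.5, clamped at 0). The helpers `bilinear_resize` and `concat_channels` are hypothetical stand-ins for mmseg's `resize` and `torch.cat(..., dim=1)` on a single sample.

```python
def _source_index(dst, in_size, out_size):
    """Map an output index to a fractional input index (align_corners=False)."""
    return max((dst + 0.5) * in_size / out_size - 0.5, 0.0)


def bilinear_resize(grid, out_h, out_w):
    """Bilinearly resize a 2D list to out_h x out_w."""
    in_h, in_w = len(grid), len(grid[0])
    out = []
    for y in range(out_h):
        sy = _source_index(y, in_h, out_h)
        y0 = int(sy)
        y1 = min(y0 + 1, in_h - 1)  # clamp at the bottom border
        ly = sy - y0
        row = []
        for x in range(out_w):
            sx = _source_index(x, in_w, out_w)
            x0 = int(sx)
            x1 = min(x0 + 1, in_w - 1)  # clamp at the right border
            lx = sx - x0
            top = grid[y0][x0] * (1 - lx) + grid[y0][x1] * lx
            bottom = grid[y1][x0] * (1 - lx) + grid[y1][x1] * lx
            row.append(top * (1 - ly) + bottom * ly)
        out.append(row)
    return out


def concat_channels(*features):
    """Concatenate [C][H][W] features along the channel axis."""
    height, width = len(features[0][0]), len(features[0][0][0])
    assert all(len(f[0]) == height and len(f[0][0]) == width for f in features)
    return [channel for f in features for channel in f]


print(bilinear_resize([[0.0, 1.0], [0.0, 1.0]], 4, 4)[0])  # [0.0, 0.25, 0.75, 1.0]

# Upsample a 1x1 "global" branch to 3x3 and stack it under a 2-channel input.
x = [[[1.0] * 3 for _ in range(3)], [[2.0] * 3 for _ in range(3)]]
global_branch = [bilinear_resize([[5.0]], 3, 3)]
print(len(concat_channels(x, global_branch)))  # 3
```

Upsampling the 1x1 global branch simply spreads its single value over the whole H x W grid, so every pixel receives the same global context vector alongside its local features.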
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
# Relative imports: this file lives inside mmseg/models/decode_heads/.
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            # One branch per pyramid level: pool to pool_scale x pool_scale
            # bins, then a 1x1 conv (+ norm + activation) to reduce channels.
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            # Upsample each branch back to the input feature map's H x W.
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # 3x3 conv that fuses the input feature map with all pyramid branches.
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        # Select the backbone feature map this head consumes (per in_index).
        x = self._transform_inputs(inputs)
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        # Concatenate the input and the upsampled branches along channels.
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)
        # Dropout (if configured) + 1x1 conv_seg classifier -> per-class logits.
        output = self.cls_seg(output)
        return output
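To see how the pieces of `PSPHead` fit together, here is a shape-only trace of its forward pass, with shapes written as (C, H, W) and the batch dimension omitted. The numbers used (2048 input channels, `channels=512`, 19 classes, a 64x64 feature map) are illustrative assumptions, and `psp_head_shapes` is a hypothetical helper that does not call mmseg.

```python
def psp_head_shapes(in_channels, channels, num_classes, height, width,
                    pool_scales=(1, 2, 3, 6)):
    """Trace (C, H, W) shapes through PSPHead.forward, step by step."""
    shapes = {"input": (in_channels, height, width)}
    # Each PPM branch: AdaptiveAvgPool2d(s) then a 1x1 ConvModule to `channels`.
    shapes["pooled"] = [(channels, s, s) for s in pool_scales]
    # resize(..., size=x.size()[2:]) brings every branch back to H x W.
    shapes["upsampled"] = [(channels, height, width) for _ in pool_scales]
    # torch.cat([x] + ppm_outs, dim=1) stacks input and branches on channels.
    concat_channels = in_channels + len(pool_scales) * channels
    shapes["concat"] = (concat_channels, height, width)
    # The 3x3 bottleneck with padding=1 keeps H x W and outputs `channels`.
    shapes["bottleneck"] = (channels, height, width)
    # cls_seg ends in a 1x1 conv from `channels` to `num_classes`.
    shapes["logits"] = (num_classes, height, width)
    return shapes


shapes = psp_head_shapes(2048, 512, 19, 64, 64)
print(shapes["pooled"])  # [(512, 1, 1), (512, 2, 2), (512, 3, 3), (512, 6, 6)]
print(shapes["concat"])  # (4096, 64, 64)
print(shapes["logits"])  # (19, 64, 64)
```

In mmseg, `channels` is a free hyperparameter rather than being derived from the number of pyramid levels; choosing 512 for a 2048-channel input happens to reproduce the paper's 1/N reduction.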
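Summary
The above is the full content of "[Semantic Segmentation] PSPNet: Pyramid Scene Parsing Network", collected and organized by 生活随笔; we hope it helps you solve the problems you encounter.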