ESRGAN:基于GAN的增强超分辨率方法(附代码解析)
作者丨左育莘
學校丨西安電子科技大學
研究方向丨計算機視覺
之前看的文章里有提到 GAN 在圖像修復時更容易得到符合視覺上效果更好的圖像,所以也是看了一些結合 GAN 的圖像修復工作。
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks?發表于 ECCV 2018 的 Workshops,作者在 SRGAN 的基礎上進行了改進,包括改進網絡的結構、判決器的判決形式,以及更換了一個用于計算感知域損失的預訓練網絡。
超分辨率生成對抗網絡(SRGAN)是一項開創性的工作,能夠在單一圖像超分辨率中生成逼真的紋理。這項工作發表于 CVPR 2017。
但是,放大后的細節通常伴隨著令人不快的偽影。為了更進一步地提升視覺質量,作者仔細研究了 SRGAN 的三個關鍵部分:1)網絡結構;2)對抗性損失;3)感知域損失。并對每一項進行改進,得到 ESRGAN。
具體而言,文章提出了一種 Residual-in-Residual Dense Block (RRDB) 的網絡單元,在這個單元中,去掉了 BN(Batch Norm)層。此外,作者借鑒了 Relativistic GAN 的想法,讓判別器預測圖像的真實性而不是圖像“是否是 fake 圖像”。
最后,文章對感知域損失進行改進,使用激活前的特征,這樣可以為亮度一致性和紋理恢復提供更強的監督。在這些改進的幫助下,ESRGAN 得到了更好的視覺質量以及更逼真和自然的紋理。
改進后的效果圖(4 倍放大):
▲?在紋理和細節上,ESRGAN都優于SRGAN
SRGAN的思考與貢獻
現有的超分辨率網絡在不同的網絡結構設計以及訓練策略下,超分辨的效果得到了很大的提升,特別是 PSNR 指標。但是,基于 PSNR 指標的模型會傾向于生成過度平滑的結果,這些結果缺少必要的高頻信息。PSNR 指標與人類觀察者的主觀評價從根本上就不統一。?
一些基于感知域信息驅動的方法已經提出來用于提升超分辨率結果的視覺質量。例如,感知域的損失函數提出來用于在特征空間(instead of 像素空間)中優化超分辨率模型;生成對抗網絡通過鼓勵網絡生成一些更接近于自然圖像的方法來提升超分辨率的質量;語義圖像先驗信息用于進一步改善恢復的紋理細節。?
通過結合上面的方法,SRGAN 模型極大地提升了超分辨率結果的視覺質量。但是 SRGAN 模型得到的圖像和 GT 圖像仍有很大的差距。
ESRGAN的改進
文章對這三點做出改進:
1. 網絡的基本單元從基本的殘差單元變為 Residual-in-Residual Dense Block (RRDB);
2. GAN 網絡改進為 Relativistic average GAN (RaGAN);
3. 改進感知域損失函數,使用激活前的 VGG 特征,這個改進會提供更尖銳的邊緣和更符合視覺的結果。
網絡結構及思想
生成器部分
首先,作者參考 SRResNet 結構作為整體的網絡結構,SRResNet 的基本結構如下:
▲?SRResNet基本結構
為了提升 SRGAN 重構的圖像質量,作者主要對生成器 G 做出如下改變:
1. 去掉所有的 BN 層;
2. 把原始的 block 變為 Residual-in-Residual Dense Block (RRDB),這個 block 結合了多層的殘差網絡和密集連接。
?
如下圖所示:
思想
BN 層的影響:對于不同的基于 PSNR 的任務(包括超分辨率和去模糊)來說,去掉 BN 層已經被證明會提高表現和減小計算復雜度。
BN 層在訓練時,使用一個 batch 的數據的均值和方差對該 batch 特征進行歸一化,在測試時,使用在整個測試集上的數據預測的均值和方差。當訓練集和測試集的統計量有很大不同的時候,BN 層就會傾向于生成不好的偽影,并且限制模型的泛化能力。
作者發現,BN 層在網絡比較深,而且在 GAN 框架下進行訓練的時候,更會產生偽影。這些偽影偶爾出現在迭代和不同的設置中,違反了對訓練穩定性能的需求。所以為了穩定的訓練和一致的性能,作者去掉了 BN 層。此外,去掉 BN 層也能提高模型的泛化能力,減少計算復雜度和內存占用。
Trick
除了上述的改進,作者也使用了一些技巧來訓練深層網絡:
1. 對殘差信息進行 scaling,即將殘差信息乘以一個 0 到 1 之間的數,用于防止不穩定;
2. 更小的初始化,作者發現當初始化參數的方差變小時,殘差結構更容易進行訓練。
判別器部分?
除了改進的生成器,作者也基于 Relativistic GAN 改進了判別器。判別器 D 使用的網絡是 VGG 網絡,SRGAN 中的判別器 D 用于估計輸入到判別器中的圖像是真實且自然圖像的概率,而 Relativistic 判別器則嘗試估計真實圖像相對來說比 fake 圖像更逼真的概率。?
如下圖所示:
具體而言,作者把標準的判別器換成 Relativistic average Discriminator(RaD),所以判別器的損失函數定義為:
對應的生成器的對抗損失函數為:
求均值的操作是通過對 mini-batch 中的所有數據求平均得到的,xf 是原始低分辨圖像經過生成器以后的圖像。
可以觀察到,對抗損失包含了 xr 和 xf,所以這個生成器受益于對抗訓練中的生成數據和實際數據的梯度,這種調整會使得網絡學習到更尖銳的邊緣和更細節的紋理。
感知域損失
文章也提出了一個更有效的感知域損失,使用激活前的特征(VGG16 網絡)。
感知域的損失當前是定義在一個預訓練的深度網絡的激活層,這一層中兩個激活了的特征的距離會被最小化。
與此相反,文章使用的特征是激活前的特征,這樣會克服兩個缺點。第一,激活后的特征是非常稀疏的,特別是在很深的網絡中。這種稀疏的激活提供的監督效果是很弱的,會造成性能低下;第二,使用激活后的特征會導致重建圖像與 GT 的亮度不一致。
如圖所示:
▲?使用激活前與激活后的特征的比較:a. 亮度 b. 細節
作者對使用的感知域損失進行了探索。與目前多數使用的用于圖像分類的 VGG 網絡構建的感知域損失相反,作者提出一種更適合于超分辨的感知域損失,這個損失基于一個用于材料識別的 VGG16 網絡(MINCNet),這個網絡更聚焦于紋理而不是物體。盡管這樣帶來的增益很小,但作者仍然相信,探索關注紋理的感知域損失對超分辨至關重要。
損失函數
經過上面對網絡模塊的定義和構建以后,再定義損失函數,就可以進行訓練了。
對于生成器 G,它的損失函數為:
即為感知域損失,作者的源碼取的是 L1 Loss,即為上面定義的生成器損失,而 L1 則為 pixel-wise 損失,即, 實驗中取。
對于判別器,其損失函數就是上面提到的:
網絡細節
生成器網絡G
要定義 RDDB 模塊,首先要定義 Dense Block,而 Dense Block 里面又有卷積層,LReLU 層以及密集連接,所以首先將卷積層和 LReLU 層進行模塊化,這部分的代碼如下(PyTorch):
???????????????pad_type='zero',?norm_type=None,?act_type='relu',?mode='CNA'):
????'''
????Conv?layer?with?padding,?normalization,?activation
????mode:?CNA?-->?Conv?->?Norm?->?Act
????????NAC?-->?Norm?->?Act?-->?Conv?(Identity?Mappings?in?Deep?Residual?Networks,?ECCV16)
????'''
????assert?mode?in?['CNA',?'NAC',?'CNAC'],?'Wong?conv?mode?[{:s}]'.format(mode)
????padding?=?get_valid_padding(kernel_size,?dilation)
????p?=?pad(pad_type,?padding)?if?pad_type?and?pad_type?!=?'zero'?else?None
????padding?=?padding?if?pad_type?==?'zero'?else?0
????c?=?nn.Conv2d(in_nc,?out_nc,?kernel_size=kernel_size,?stride=stride,?padding=padding,?\
????????????dilation=dilation,?bias=bias,?groups=groups)
????a?=?act(act_type)?if?act_type?else?None
????if?'CNA'?in?mode:
????????n?=?norm(norm_type,?out_nc)?if?norm_type?else?None
????????return?sequential(p,?c,?n,?a)
????elif?mode?==?'NAC':
????????if?norm_type?is?None?and?act_type?is?not?None:
????????????a?=?act(act_type,?inplace=False)
????????????#?Important!
????????????#?input----ReLU(inplace)----Conv--+----output
????????????#????????|________________________|
????????????#?inplace?ReLU?will?modify?the?input,?therefore?wrong?output
????????n?=?norm(norm_type,?in_nc)?if?norm_type?else?None
????????return?sequential(n,?a,?p,?c)
注意這里的 pad_type='zero' 并不是指 padding=0,源碼中定義了兩個函數,針對不同模式下的 padding:
def?pad(pad_type,?padding):
????#?helper?selecting?padding?layer
????#?if?padding?is?'zero',?do?by?conv?layers
????pad_type?=?pad_type.lower()
????if?padding?==?0:
????????return?None
????if?pad_type?==?'reflect':
????????layer?=?nn.ReflectionPad2d(padding)
????elif?pad_type?==?'replicate':
????????layer?=?nn.ReplicationPad2d(padding)
????else:
????????raise?NotImplementedError('padding?layer?[{:s}]?is?not?implemented'.format(pad_type))
????return?layer
def?get_valid_padding(kernel_size,?dilation):
????kernel_size?=?kernel_size?+?(kernel_size?-?1)?*?(dilation?-?1)
????padding?=?(kernel_size?-?1)?//?2
????return?padding
所以當 pad_type='zero' 時,執行的是 get_valid_padding 函數,根據輸入參數可知此時 padding=1。?
模塊化以后,對 Dense Block 進行定義:
????'''
????Residual?Dense?Block
????style:?5?convs
????The?core?module?of?paper:?(Residual?Dense?Network?for?Image?Super-Resolution,?CVPR?18)
????'''
????def?__init__(self,?nc,?kernel_size=3,?gc=32,?stride=1,?bias=True,?pad_type='zero',?\
????????????norm_type=None,?act_type='leakyrelu',?mode='CNA'):
????????super(ResidualDenseBlock_5C,?self).__init__()
????????#?gc:?growth?channel,?i.e.?intermediate?channels
????????self.conv1?=?conv_block(nc,?gc,?kernel_size,?stride,?bias=bias,?pad_type=pad_type,?\
????????????norm_type=norm_type,?act_type=act_type,?mode=mode)
????????self.conv2?=?conv_block(nc+gc,?gc,?kernel_size,?stride,?bias=bias,?pad_type=pad_type,?\
????????????norm_type=norm_type,?act_type=act_type,?mode=mode)
????????self.conv3?=?conv_block(nc+2*gc,?gc,?kernel_size,?stride,?bias=bias,?pad_type=pad_type,?\
????????????norm_type=norm_type,?act_type=act_type,?mode=mode)
????????self.conv4?=?conv_block(nc+3*gc,?gc,?kernel_size,?stride,?bias=bias,?pad_type=pad_type,?\
????????????norm_type=norm_type,?act_type=act_type,?mode=mode)
????????if?mode?==?'CNA':
????????????last_act?=?None
????????else:
????????????last_act?=?act_type
????????self.conv5?=?conv_block(nc+4*gc,?nc,?3,?stride,?bias=bias,?pad_type=pad_type,?\
????????????norm_type=norm_type,?act_type=last_act,?mode=mode)
????def?forward(self,?x):
????????x1?=?self.conv1(x)
????????x2?=?self.conv2(torch.cat((x,?x1),?1))
????????x3?=?self.conv3(torch.cat((x,?x1,?x2),?1))
????????x4?=?self.conv4(torch.cat((x,?x1,?x2,?x3),?1))
????????x5?=?self.conv5(torch.cat((x,?x1,?x2,?x3,?x4),?1))
????????return?x5.mul(0.2)?+?x
前面提到的對殘差信息進行 scaling,在這里可以看出來,系數為 0.2。可以看到在 kernel size(3×3)和 stride=1,padding=1 的設置下,特征圖的大小始終不變,但是通道數由于 concat 的原因,每次都會增加 gc 個通道,但是會在最后一層由變回原來的通道數 nc,這里的參數 norm_type=None,表示不要 Batch Norm。?
定義了 Dense Block 以后,就可以組成 RDDB 了:
????'''
????Residual?in?Residual?Dense?Block
????(ESRGAN:?Enhanced?Super-Resolution?Generative?Adversarial?Networks)
????'''
????def?__init__(self,?nc,?kernel_size=3,?gc=32,?stride=1,?bias=True,?pad_type='zero',?\
????????????norm_type=None,?act_type='leakyrelu',?mode='CNA'):
????????super(RRDB,?self).__init__()
????????self.RDB1?=?ResidualDenseBlock_5C(nc,?kernel_size,?gc,?stride,?bias,?pad_type,?\
????????????norm_type,?act_type,?mode)
????????self.RDB2?=?ResidualDenseBlock_5C(nc,?kernel_size,?gc,?stride,?bias,?pad_type,?\
????????????norm_type,?act_type,?mode)
????????self.RDB3?=?ResidualDenseBlock_5C(nc,?kernel_size,?gc,?stride,?bias,?pad_type,?\
????????????norm_type,?act_type,?mode)
????def?forward(self,?x):
????????out?=?self.RDB1(x)
????????out?=?self.RDB2(out)
????????out?=?self.RDB3(out)
????????return?out.mul(0.2)?+?x
因為特征圖大小始終不變,所以需要定義上采樣模塊進行放大,得到最后的結果:
????????????????pad_type='zero',?norm_type=None,?act_type='relu',?mode='nearest'):
????#?Up?conv
????#?described?in?https://distill.pub/2016/deconv-checkerboard/
????upsample?=?nn.Upsample(scale_factor=upscale_factor,?mode=mode)
????conv?=?conv_block(in_nc,?out_nc,?kernel_size,?stride,?bias=bias,?\
????????????????????????pad_type=pad_type,?norm_type=norm_type,?act_type=act_type)
????return?sequential(upsample,?conv)
參考 SRResNet,還需要一個 Shortcut 連接模塊:
????#Elementwise?sum?the?output?of?a?submodule?to?its?input
????def?__init__(self,?submodule):
????????super(ShortcutBlock,?self).__init__()
????????self.sub?=?submodule
????def?forward(self,?x):
????????output?=?x?+?self.sub(x)
????????return?output
????def?__repr__(self):
????????tmpstr?=?'Identity?+?\n|'
????????modstr?=?self.sub.__repr__().replace('\n',?'\n|')
????????tmpstr?=?tmpstr?+?modstr
????????return?tmpstr
定義好上面的模塊以后,就可以定義生成器網絡 G(RDDBNet):
????def?__init__(self,?in_nc,?out_nc,?nf,?nb,?gc=32,?upscale=4,?norm_type=None,?\
????????????act_type='leakyrelu',?mode='CNA',?upsample_mode='upconv'):
????????super(RRDBNet,?self).__init__()
????????n_upscale?=?int(math.log(upscale,?2))
????????if?upscale?==?3:
????????????n_upscale?=?1
????????fea_conv?=?B.conv_block(in_nc,?nf,?kernel_size=3,?norm_type=None,?act_type=None)
????????rb_blocks?=?[B.RRDB(nf,?kernel_size=3,?gc=32,?stride=1,?bias=True,?pad_type='zero',?\
????????????norm_type=norm_type,?act_type=act_type,?mode='CNA')?for?_?in?range(nb)]
????????LR_conv?=?B.conv_block(nf,?nf,?kernel_size=3,?norm_type=norm_type,?act_type=None,?mode=mode)
????????if?upsample_mode?==?'upconv':
????????????upsample_block?=?B.upconv_blcok
????????elif?upsample_mode?==?'pixelshuffle':
????????????upsample_block?=?B.pixelshuffle_block
????????else:
????????????raise?NotImplementedError('upsample?mode?[{:s}]?is?not?found'.format(upsample_mode))
????????if?upscale?==?3:
????????????upsampler?=?upsample_block(nf,?nf,?3,?act_type=act_type)
????????else:
????????????upsampler?=?[upsample_block(nf,?nf,?act_type=act_type)?for?_?in?range(n_upscale)]
????????HR_conv0?=?B.conv_block(nf,?nf,?kernel_size=3,?norm_type=None,?act_type=act_type)
????????HR_conv1?=?B.conv_block(nf,?out_nc,?kernel_size=3,?norm_type=None,?act_type=None)
????????self.model?=?B.sequential(fea_conv,?B.ShortcutBlock(B.sequential(*rb_blocks,?LR_conv)),\
????????????*upsampler,?HR_conv0,?HR_conv1)
????def?forward(self,?x):
????????x?=?self.model(x)
????????return?x
注意到這里有個參數 nb,這個參數控制網絡中 RDDB 的數量,作者取的是 23。
判別器網絡D
前面提到,判別器 D 的網絡結構為 VGG 網絡,定義如下(輸入圖像 size 為 128×128):
class?Discriminator_VGG_128(nn.Module):
????def?__init__(self,?in_nc,?base_nf,?norm_type='batch',?act_type='leakyrelu',?mode='CNA'):
????????super(Discriminator_VGG_128,?self).__init__()
????????#?features
????????#?hxw,?c
????????#?128,?64
????????conv0?=?B.conv_block(in_nc,?base_nf,?kernel_size=3,?norm_type=None,?act_type=act_type,?\
????????????mode=mode)
????????conv1?=?B.conv_block(base_nf,?base_nf,?kernel_size=4,?stride=2,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????#?64,?64
????????conv2?=?B.conv_block(base_nf,?base_nf*2,?kernel_size=3,?stride=1,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????conv3?=?B.conv_block(base_nf*2,?base_nf*2,?kernel_size=4,?stride=2,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????#?32,?128
????????conv4?=?B.conv_block(base_nf*2,?base_nf*4,?kernel_size=3,?stride=1,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????conv5?=?B.conv_block(base_nf*4,?base_nf*4,?kernel_size=4,?stride=2,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????#?16,?256
????????conv6?=?B.conv_block(base_nf*4,?base_nf*8,?kernel_size=3,?stride=1,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????conv7?=?B.conv_block(base_nf*8,?base_nf*8,?kernel_size=4,?stride=2,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????#?8,?512
????????conv8?=?B.conv_block(base_nf*8,?base_nf*8,?kernel_size=3,?stride=1,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????conv9?=?B.conv_block(base_nf*8,?base_nf*8,?kernel_size=4,?stride=2,?norm_type=norm_type,?\
????????????act_type=act_type,?mode=mode)
????????#?4,?512
????????self.features?=?B.sequential(conv0,?conv1,?conv2,?conv3,?conv4,?conv5,?conv6,?conv7,?conv8,\
????????????conv9)
????????#?classifier
????????self.classifier?=?nn.Sequential(
????????????nn.Linear(512?*?4?*?4,?100),?nn.LeakyReLU(0.2,?True),?nn.Linear(100,?1))
????def?forward(self,?x):
????????x?=?self.features(x)
????????x?=?x.view(x.size(0),?-1)
????????x?=?self.classifier(x)
????????return?x
可以看到,這里使用了 batch norm,層間的激活函數為 leakyReLU,base_nf參數為基礎通道數,為 64。經過特征提取以后,原本為 128×128×1(/3) 的輸入圖像輸出為 4×4×512。再經過其定義的 classifier 得到輸出值。?
提取感知域損失的網絡(Perceptual Network)?
文章使用了一個用于材料識別的 VGG16 網絡(MINCNet)來提取感知域特征,定義如下:
????def?__init__(self):
????????super(MINCNet,?self).__init__()
????????self.ReLU?=?nn.ReLU(True)
????????self.conv11?=?nn.Conv2d(3,?64,?3,?1,?1)
????????self.conv12?=?nn.Conv2d(64,?64,?3,?1,?1)
????????self.maxpool1?=?nn.MaxPool2d(2,?stride=2,?padding=0,?ceil_mode=True)
????????self.conv21?=?nn.Conv2d(64,?128,?3,?1,?1)
????????self.conv22?=?nn.Conv2d(128,?128,?3,?1,?1)
????????self.maxpool2?=?nn.MaxPool2d(2,?stride=2,?padding=0,?ceil_mode=True)
????????self.conv31?=?nn.Conv2d(128,?256,?3,?1,?1)
????????self.conv32?=?nn.Conv2d(256,?256,?3,?1,?1)
????????self.conv33?=?nn.Conv2d(256,?256,?3,?1,?1)
????????self.maxpool3?=?nn.MaxPool2d(2,?stride=2,?padding=0,?ceil_mode=True)
????????self.conv41?=?nn.Conv2d(256,?512,?3,?1,?1)
????????self.conv42?=?nn.Conv2d(512,?512,?3,?1,?1)
????????self.conv43?=?nn.Conv2d(512,?512,?3,?1,?1)
????????self.maxpool4?=?nn.MaxPool2d(2,?stride=2,?padding=0,?ceil_mode=True)
????????self.conv51?=?nn.Conv2d(512,?512,?3,?1,?1)
????????self.conv52?=?nn.Conv2d(512,?512,?3,?1,?1)
????????self.conv53?=?nn.Conv2d(512,?512,?3,?1,?1)
????def?forward(self,?x):
????????out?=?self.ReLU(self.conv11(x))
????????out?=?self.ReLU(self.conv12(out))
????????out?=?self.maxpool1(out)
????????out?=?self.ReLU(self.conv21(out))
????????out?=?self.ReLU(self.conv22(out))
????????out?=?self.maxpool2(out)
????????out?=?self.ReLU(self.conv31(out))
????????out?=?self.ReLU(self.conv32(out))
????????out?=?self.ReLU(self.conv33(out))
????????out?=?self.maxpool3(out)
????????out?=?self.ReLU(self.conv41(out))
????????out?=?self.ReLU(self.conv42(out))
????????out?=?self.ReLU(self.conv43(out))
????????out?=?self.maxpool4(out)
????????out?=?self.ReLU(self.conv51(out))
????????out?=?self.ReLU(self.conv52(out))
????????out?=?self.conv53(out)
????????return?out
再引入預訓練參數,就可以進行特征提取:
????def?__init__(self,?feature_layer=34,?use_bn=False,?use_input_norm=True,?\
????????????????device=torch.device('cpu')):
????????super(MINCFeatureExtractor,?self).__init__()
????????self.features?=?MINCNet()
????????self.features.load_state_dict(
????????????torch.load('../experiments/pretrained_models/VGG16minc_53.pth'),?strict=True)
????????self.features.eval()
????????#?No?need?to?BP?to?variable
????????for?k,?v?in?self.features.named_parameters():
????????????v.requires_grad?=?False
????def?forward(self,?x):
????????output?=?self.features(x)
????????return?output
網絡插值思想
為了平衡感知質量和 PSNR 等評價值,作者提出了一個靈活且有效的方法——網絡插值。具體而言,作者首先基于 PSNR 方法訓練的得到的網絡 G_PSNR,然后再用基于 GAN 的網絡 G_GAN 進行 finetune。
然后,對這兩個網絡相應的網絡參數進行插值得到一個插值后的網絡 G_INTERP:
這樣就可以通過 α 值來調整效果。
訓練細節
放大倍數:4;mini-batch:16。
通過 Matlab 的 bicubic 函數對 HR 圖像進行降采樣得到 LR 圖像。?
HR patch 大小:128×128。實驗發現使用大的 patch 時,訓練一個深層網絡效果會更好,因為一個增大的感受域會幫助模型捕捉更具有語義的信息。
訓練過程如下:?
1. 訓練一個基于 PSNR 指標的模型(L1 Loss),初始化學習率:2×1e-4,每 200000 個 mini-batch 學習率除以 2;
2. 以 1 中訓練的模型作為生成器的初始化。
初始學習率:1e-4,并在 50k,100k,200k,300k 迭代后減半。
一個基于像素損失函數進行優化的預訓練模型會幫助基于 GAN 的模型生成更符合視覺的結果,原因如下:
1. 可以避免生成器不希望的局部最優;
2. 再預訓練以后,判別器所得到的輸入圖像的質量是相對較好的,而不是完全初始化的圖像,這樣會使判別器更關注到紋理的判別。
優化器:Adam (β1=0.9, β2=0.999);交替更新生成器和判別器,直到收斂。?
生成器的設置:1.16 層(基本的殘差結構);2.23層(RDDB)。
數據集:DIV2K,Flickr2K,OST(有豐富紋理信息的數據集會是模型產生更自然的結果)。
對比實驗(4倍放大)
針對文中提到的各種改進,包括移除 BN,使用激活前特征作為感知域特征,修改 GAN 的判別條件,以及提出的 RDDB,作者做了詳細的對比試驗,結果如下:
經過實驗以后,作者得出結論:?
1. 去掉 BN:并沒有降低網絡的性能,而且節省了計算資源和內存占用。而且發現當網絡變深變復雜時,帶 BN 層的模型更傾向于產生影響視覺效果的偽影;
2. 使用激活前的特征:得到的圖像的亮度更準確,而且可以產生更尖銳的邊緣和更豐富的細節;
3. RaGAN:產生更尖銳的邊緣和更豐富的細節;
4. RDDB:更加提升恢復得到的紋理(因為深度模型具有強大的表示能力來捕獲語義信息),而且可以去除噪聲。
網絡插值實驗?
為了平衡視覺效果和 PSNR 等性能指標,作者對網絡插值參數 α 的取值進行了實驗,結果如下:
此外,作者還對比了網絡插值和圖像插值的效果。圖像插值即指分別由兩個網絡輸出的圖像進行插值。通過對比實驗可以看到,圖像插值對消除偽影的效果不如網絡插值。
與SOTA方法對比(4倍放大)
可以看到,ESRGAN 得到的圖像 PSNR 值不高,但是從視覺效果上看會更好,Percpetual Index 值更小(越小越好),而且 ESRGAN 在 PIRM-SR 競賽上也獲得了第一名(在 Percpetual Index 指標上)。
總結
文章提出的 ESRGAN 在 SRGAN 的基礎上做出了改進,包括去除 BN 層,基本結構換成 RDDB,改進 GAN 中判別器的判別目標,以及使用激活前的特征構成感知域損失函數,實驗證明這些改進對提升輸出圖像的視覺效果都有作用。
此外,作者也使用了一些技巧來提升網絡的性能,包括對殘差信息的 scaling,以及更小的初始化。最后,作者使用了一種網絡插值的方法來平衡輸出圖像的視覺效果和 PSNR 等指標值。
#拉 票 時 間#
?請為PaperWeekly打call!?
PaperWeekly 曾于 2016 年入選“年度十大學術公眾號”(2016年度學術公眾號TOP10重磅發布),并連續入圍“2017 年度學術公眾號”評選。?
如果您喜歡 PaperWeekly,請在本文底部點擊“閱讀原文”-勾選“PaperWeekly”,為我們投上您寶貴的一票。
?
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
點擊“閱讀原文”,為“PaperWeekly”投票!
總結
以上是生活随笔為你收集整理的ESRGAN:基于GAN的增强超分辨率方法(附代码解析)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 知识图谱最新论文清单,高阶炼丹师为你逐一
- 下一篇: 一年学遍吴恩达、李飞飞、周志华等16大精