对抗生成网络GAN系列——GANomaly原理及源码解析
🍊作者簡介:禿頭小蘇,致力于用最通俗的語言描述問題
🍊往期回顧:對抗生成網絡GAN系列——GAN原理及手寫數字生成小案例 ??對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例 ??對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰 ??對抗生成網絡GAN系列——EGBAD原理及缺陷檢測實戰
🍊近期目標:寫好專欄的每一篇文章
🍊支持小蘇:點贊👍🏼、收藏?、留言📩
?
文章目錄
- 對抗生成網絡GAN系列——GANomaly原理及源碼解析
- 寫在前面
- GANomaly原理解析
- GANomaly結構
- GANomaly損失函數
- GANomaly測試階段
- GANomaly源碼解析
- GANomaly模型搭建
- GANomaly損失函數
- 小結
- 參考鏈接
對抗生成網絡GAN系列——GANomaly原理及源碼解析
寫在前面
? 在前面,我已經介紹過好幾篇有關GAN的文章,鏈接如下:
- [1]對抗生成網絡GAN系列——GAN原理及手寫數字生成小案例 🍁🍁🍁
- [2]對抗生成網絡GAN系列——DCGAN簡介及人臉圖像生成案例🍁🍁🍁
- [3]對抗生成網絡GAN系列——CycleGAN原理🍁🍁🍁
- [4] 對抗生成網絡GAN系列——AnoGAN原理及缺陷檢測實戰 🍁🍁🍁
- [5]對抗生成網絡GAN系列——EGBAD原理及缺陷檢測實戰🍁🍁🍁
- [6]對抗生成網絡GAN系列——WGAN原理及實戰演練🍁🍁🍁
??這篇文章我將來為大家介紹GANomaly,論文名為:Semi-Supervised Anomaly Detection via Adversarial Training。這篇文章同樣是實現缺陷檢測的,因此在閱讀本文之前建議你對使用GAN網絡實現缺陷檢測有一定的了解,可以參考上文鏈接中的[4]和[5]。
??準備好了嗎,嘟嘟嘟,開始發車。🚖🚖🚖
?
GANomaly原理解析
【閱讀此部分前建議對GAN的原理及GAN在缺陷檢測上的應用有所了解,詳情點擊寫在前面中的鏈接查看,本篇文章我不會再介紹GAN的一些先驗知識。】
GANomaly結構
? 這部分為大家介紹GANomaly的原理,其實我們一起來看下圖就足夠了:
?
圖1 GANomaly結構圖??我們還是先來對上圖中的結構做一些解釋。從直觀的顏色上來看,我們可以分成兩類,一類是紅色的Encoder結構,一類是藍色的Decoder結構。Encoder主要就是降維的作用啦,如將一張張圖片數據壓縮成一個個潛在向量;相反,Decoder就是升維的作用,如將一個個潛在向量重建成一張張圖片。按照論文描述的結構來分,可以分成三個子結構,分別為生成器網絡G,編碼器網絡E和判別器網絡D。下面分別來介紹介紹這三個子結構:
-
生成器網絡G
??生成器網絡G由兩個部分組成,分別為編碼器GE(x))G_E(x))GE?(x))和解碼器GD(z)G_D(z)GD?(z),其實這就是一個自動編碼器結構,主要用來學習輸入x的數據分布并重建圖像x^{\hat x}x^。我們一個個來看,先看GE(x)G_E(x)GE?(x)結構,假設我們的輸入x維度為RC×H×W\mathbb{R}^{C×H×W}RC×H×W,經過GE(x)G_E(x)GE?(x)結構后,變成一個向量zzz,其維度為Rd\mathbb{R}^dRd。【GE(x)G_E(x)GE?(x)具體結構很簡單啦,這里就不詳細介紹了。我會在源碼解析部分給出,大家肯定一看就會。】接著我們來看GD(z)G_D(z)GD?(z)結構,它會將剛剛得到的向量z上采樣成x^\hat xx^,x^\hat xx^的維度和xxx一致,都為RC×H×W\mathbb{R}^{C×H×W}RC×H×W。關于GD(Z)G_D(Z)GD?(Z)結構也很簡單,其主要用到了轉置卷積,對于轉置卷積不了解的可以看博客[2]了解詳情。生成器網絡G就為大家介紹完了,是不是發現很簡單呢。總結下來就兩步,第一步讓輸入x通過GE(x)G_E(x)GE?(x)得到z,第二步讓z通過GD(Z)G_D(Z)GD?(Z)變成x^\hat xx^。這兩步也可以用一步表示,即x^=G(x)\hat x=G(x)x^=G(x)。??思來想去我還是想在這里給大家拋出一個問題,我們傳統的GAN是怎么通過生成器來構建假圖像的呢?和GANomaly有區別嗎?其實這個問題的答案很簡單,大家都稍稍思考一下,我就不給答案了,不明白的評論區見吧!!!🥂🥂🥂
?
-
編碼器網絡E
? ??編碼器網絡E的作用是將生成器得到的x^\hat xx^壓縮成一個向量z^\hat zz^,是不是發現和生成器網絡中的GE(x)G_E(x)GE?(x)很像呢,其實呀,它倆的結構就是完全一樣的,生成的z^\hat zz^ 和x^\hat xx^ 的維度一致,這是方便后面的損失比較。
?
-
判別器網絡D
? ??判別器網絡D和我們之前介紹DCGAN時的結構是一樣的,都是將真實數據xxx和生成數據x^\hat xx^輸入網絡,然后得出一個分數。
?
GANomaly損失函數
??GANomaly的損失函數分為兩部分,第一部分是生成器損失,第二部分為判別器損失,下面我們分別來進行介紹:
-
生成器損失函數
? 生成器損失函數又由三個部分組成,分別如下:
-
Adversari Loss
我還是直接上公式吧,如下:
? Ladv=Ex~px∣∣f(x)?Ex~pxf(G(x))∣∣2L_{adv}=E_{x \sim px}||f(x)-E_{x \sim px}f(G(x))||_2Ladv?=Ex~px?∣∣f(x)?Ex~px?f(G(x))∣∣2?
這個公式對應圖一中的Ladv=∣∣f(x)?f(x^)∣∣2L_{adv}=||f(x)-f(\hat x)||_2Ladv?=∣∣f(x)?f(x^)∣∣2?🍵🍵🍵這個損失函數應該很好理解,在前面介紹的GAN網絡都有提及,f(?)f(*)f(?)表示判別器網絡某個中間層的輸出。這個損失函數的作用就是讓兩張圖像x和x^x和\hat xx和x^盡可能接近,也就是讓生成器生成的圖片更加逼真。
-
Contextual Loss
同樣的,直接來上公式,如下:
? Lcon=Ex~px∣∣x?G(x)∣∣1L_{con}=E_{x \sim px}||x-G(x)||_1Lcon?=Ex~px?∣∣x?G(x)∣∣1?
這個公式對應圖一中的Lcon=∣∣x?x^∣∣1L_{con}=||x-\hat x||_1Lcon?=∣∣x?x^∣∣1?🍵🍵🍵這個函數其實也是要讓兩張圖像x和x^x和\hat xx和x^盡可能接近。至于這里為什么用的是L1范數而不是L2范數,作者在論文中說這里使用L1范數的效果要比使用L2范數的效果好,這屬于實驗得到的結論,大家也不用過于糾結。
-
Encoder Loss
話不多說,上公式,如下:
? Lenc=Ex~px∣∣GE(x)?E(G(x))∣∣2L_{enc}=E_{x \sim px}||G_E(x)-E(G(x))||_2Lenc?=Ex~px?∣∣GE?(x)?E(G(x))∣∣2?
這個公式對應圖一中的Lenc=∣∣z?z^∣∣2L_{enc}=||z-\hat z||_2Lenc?=∣∣z?z^∣∣2?🍵🍵🍵這里的損失函數在我看來主要作用就是讓我們在推理過程中的效果更好,這里就像AnoGAN中不斷搜索最優的那個z的作用。
如果大家這里讀過cycleGAN的論文的話,可能會覺得這個損失函數有點類似cycleGAN中的循環一致性損失。我覺得這篇文章的思想可能借鑒了cycleGAN中的思想,感興趣的可以去閱讀一下,非常有意思的一篇文章!!!🥃🥃🥃
生成器總的損失是上述三種損失的加權和,如下:
L=wadvLadv+wconLcon+wencLencL=w_{adv}L_{adv}+w_{con}L_{con}+w_{enc}L_{enc}L=wadv?Ladv?+wcon?Lcon?+wenc?Lenc?在論文提供的源碼中,默認wcon=50,wadv=wenc=1w_{con}=50,w_{adv}=w_{enc}=1wcon?=50,wadv?=wenc?=1。
?
-
-
判別器損失函數
判別器的損失函數就和原始GAN一樣,如下:【不清楚的點擊???了解詳情】
這部分我直接先放上代碼吧,不多,也很容易理解,如下:
self.l_bce = nn.BCELoss() # Real - Fake Loss self.err_d_real = self.l_bce(self.pred_real, self.real_label) self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label)# NetD Loss & Backward-Pass self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
?
GANomaly測試階段
??在上一小節,為大家介紹了GANomaly的損失函數,這是在測試階段使用的。GANomaly針對的是異常檢測任務,在測試階段我們會對輸入的數據進行評分,根據評分的結果來判定輸入是否異常。在GANomaly中使用的評分函數就是我們上一小節介紹的Encoder Loss,對于一個測試數據x,用A(x)A(x)A(x)表示其異常得分,則:
? A(x)=∣∣GE(x)?E(G(x))∣∣2A(x)=||G_E(x)-E(G(x))||_2A(x)=∣∣GE?(x)?E(G(x))∣∣2?
??這里大家需要注意以下,論文中A(x)A(x)A(x)的表達式使用的是L1范數,但是從我閱讀論文提供的源碼來看,代碼中使用的是L2范數。這里保持和源碼一致,使用L2范數。代碼中關于此部分的描述如下:
# latent_i表示G_E(x),latent_o表示E(G(x))。torch.pow(m,2)=m^2 error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)?
?
GANomaly源碼解析
? 這里直接使用論文中提供的源碼地址:GANomaly源碼🌱🌱🌱
GANomaly模型搭建
???其實通過我前文的講解,不知道大家能否感受到GANomaly模型其實是不復雜的。需要注意的是在介紹GANomaly結構時我們將模型分為了三個子結構,分別為生成器網絡G、編碼器網絡E、判別器網絡D。但是在代碼中我們將生成器網絡G和編碼器網絡E合并在一塊兒了,也稱為生成器網絡G。
??下面我給出這部分的代碼,大家注意一下這里面的超參數比較多,為了方便大家閱讀,我把這里用到超參數的整理出來,如下圖所示:
""" Network architectures. """ # pylint: disable=W0221,W0622,C0103,R0913## import torch import torch.nn as nn import torch.nn.parallel from options import Options## def weights_init(mod):"""Custom weights initialization called on netG, netD and netE:param m::return:"""classname = mod.__class__.__name__if classname.find('Conv') != -1:mod.weight.data.normal_(0.0, 0.02)elif classname.find('BatchNorm') != -1:mod.weight.data.normal_(1.0, 0.02)mod.bias.data.fill_(0)### class Encoder(nn.Module):"""DCGAN ENCODER NETWORK"""def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):super(Encoder, self).__init__()self.ngpu = ngpuassert isize % 16 == 0, "isize has to be a multiple of 16"main = nn.Sequential()# input is nc x isize x isizemain.add_module('initial-conv-{0}-{1}'.format(nc, ndf),nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))main.add_module('initial-relu-{0}'.format(ndf),nn.LeakyReLU(0.2, inplace=True))csize, cndf = isize / 2, ndf # csize=16,cndf=64# Extra layersfor t in range(n_extra_layers):main.add_module('extra-layers-{0}-{1}-conv'.format(t, cndf),nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cndf),nn.BatchNorm2d(cndf))main.add_module('extra-layers-{0}-{1}-relu'.format(t, cndf),nn.LeakyReLU(0.2, inplace=True))while csize > 4:in_feat = cndfout_feat = cndf * 2main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat),nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))main.add_module('pyramid-{0}-batchnorm'.format(out_feat),nn.BatchNorm2d(out_feat))main.add_module('pyramid-{0}-relu'.format(out_feat),nn.LeakyReLU(0.2, inplace=True))cndf = cndf * 2csize = csize / 2# state size. K x 4 x 4if add_final_conv:main.add_module('final-{0}-{1}-conv'.format(cndf, 1),nn.Conv2d(cndf, nz, 4, 1, 0, bias=False))self.main = maindef forward(self, input):if self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return output## class Decoder(nn.Module):"""DCGAN DECODER NETWORK"""def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):super(Decoder, self).__init__()self.ngpu = ngpuassert isize % 16 == 0, "isize has to be a multiple of 16"cngf, tisize = ngf // 2, 4 #cngf=32 ,tisize=4while tisize != isize:cngf = cngf * 2tisize = tisize * 2main = nn.Sequential()# input is Z, going into a convolutionmain.add_module('initial-{0}-{1}-convt'.format(nz, cngf),nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))main.add_module('initial-{0}-batchnorm'.format(cngf),nn.BatchNorm2d(cngf))main.add_module('initial-{0}-relu'.format(cngf),nn.ReLU(True))csize, _ = 4, cngfwhile csize < isize // 2:main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2),nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2),nn.BatchNorm2d(cngf // 2))main.add_module('pyramid-{0}-relu'.format(cngf // 2),nn.ReLU(True))cngf = cngf // 2csize = csize * 2# Extra layersfor t in range(n_extra_layers):main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf),nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cngf),nn.BatchNorm2d(cngf))main.add_module('extra-layers-{0}-{1}-relu'.format(t, cngf),nn.ReLU(True))main.add_module('final-{0}-{1}-convt'.format(cngf, nc),nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))main.add_module('final-{0}-tanh'.format(nc),nn.Tanh())self.main = maindef forward(self, input):if self.ngpu > 1:output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))else:output = self.main(input)return output## 判別器網絡結構 class NetD(nn.Module):"""DISCRIMINATOR NETWORK"""def __init__(self, opt):super(NetD, self).__init__()model = Encoder(opt.isize, 1, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)layers = list(model.main.children())self.features = nn.Sequential(*layers[:-1])self.classifier = nn.Sequential(layers[-1])self.classifier.add_module('Sigmoid', nn.Sigmoid())def forward(self, x):features = self.features(x)features = featuresclassifier = self.classifier(features)classifier = classifier.view(-1, 1).squeeze(1)return classifier, features## 生成器網絡結構 class NetG(nn.Module):"""GENERATOR NETWORK"""def __init__(self, opt):super(NetG, self).__init__()self.encoder1 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)self.decoder = Decoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)self.encoder2 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)def forward(self, x):latent_i = self.encoder1(x)gen_imag = self.decoder(latent_i)latent_o = self.encoder2(gen_imag)return gen_imag, latent_i, latent_o?
GANomaly損失函數
???我們在理論部分已經介紹了GANomaly的損失函數,那么在代碼上它們都是一一對應的,實現起來也很簡單,如下:
## 定義L1 Loss def l1_loss(input, target):return torch.mean(torch.abs(input - target))## 定義L2 Loss def l2_loss(input, target, size_average=True):if size_average:return torch.mean(torch.pow((input-target), 2))else:return torch.pow((input-target), 2)self.l_adv = l2_loss self.l_con = nn.L1Loss() self.l_enc = l2_lossself.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1]) self.err_g_con = self.l_con(self.fake, self.input) self.err_g_enc = self.l_enc(self.latent_o, self.latent_i) self.err_g = self.err_g_adv * self.opt.w_adv + \self.err_g_con * self.opt.w_con + \self.err_g_enc * self.opt.w_enc??上述代碼為GANomaly生成器損失函數代碼,判別器的損失函數代碼已經在理論部分為大家介紹了,這里就不在贅述了。🍄🍄🍄
?
?
小結
??這里我并沒有很詳細的為大家解讀代碼,但是把一些關鍵的部分都給大家介紹了。會了這些其實你完全可以自己實現一個GANomaly網絡,或者對我之前在Anogan中的代碼稍加改造也可以達到一樣的效果。論文中提供的源碼感興趣的大家可以自己去調試一下,代碼量也不算多,但有的地方理解起來也有一定的困難,總之大家加油吧!!!🌼🌼🌼
?
參考鏈接
GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training 🍁🍁🍁
GANomaly 異常檢測的經典之作|ACCV 2018 🍁🍁🍁
?
?
如若文章對你有所幫助,那就🛴🛴🛴
總結
以上是生活随笔為你收集整理的对抗生成网络GAN系列——GANomaly原理及源码解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: LiveGBS流媒体平台GB/T2818
- 下一篇: 携程行程订单团队敏捷之旅