[GAN学习系列] 初识GAN
本文大約 3800 字,閱讀大約需要 8 分鐘
要說最近幾年在深度學習領域最火的莫過于生成對抗網絡,即 Generative Adversarial Networks(GANs)了。它是 Ian Goodfellow 在 2014 年發表的,也是這四年來出現的各種 GAN 的變種的開山鼻祖了,下圖表示這四年來有關 GAN 的論文的每個月發表數量,可以看出在 2014 年提出后到 2016 年相關的論文是比較少的,但是從 2016 年,或者是 2017 年到今年這兩年的時間,相關的論文是真的呈現井噴式增長。
那么,GAN 究竟是什么呢,它為何會成為這幾年這么火的一個研究領域呢?
GAN,即生成對抗網絡,是一個生成模型,也是半監督和無監督學習模型,它可以在不需要大量標注數據的情況下學習深度表征。最大的特點就是提出了一種讓兩個深度網絡對抗訓練的方法。
目前機器學習按照數據集是否有標簽可以分為三種,監督學習、半監督學習和無監督學習,發展最成熟,效果最好的目前還是監督學習的方法,但是在數據集數量要求更多更大的情況下,獲取標簽的成本也更加昂貴了,因此越來越多的研究人員都希望能夠在無監督學習方面有更好的發展,而 GAN 的出現,一來它是不太需要很多標注數據,甚至可以不需要標簽,二來它可以做到很多事情,目前對它的應用包括圖像合成、圖像編輯、風格遷移、圖像超分辨率以及圖像轉換等。
比如字體的轉換,在 zi2zi 這個項目中,給出了對中文文字的字體的變換,效果如下圖所示,GAN 可以學習到不同字體,然后將其進行變換。
除了字體的學習,還有對圖片的轉換, pix2pix 就可以做到,其結果如下圖所示,分割圖變成真實照片,從黑白圖變成彩色圖,從線條畫變成富含紋理、陰影和光澤的圖等等,這些都是這個 pix2pixGAN 實現的結果。
CycleGAN 則可以做到風格遷移,其實現結果如下圖所示,真實照片變成印象畫,普通的馬和斑馬的互換,季節的變換等。
上述是 GAN 的一些應用例子,接下來會簡單介紹 GAN 的原理以及其優缺點,當然也還有為啥等它提出兩年后才開始有越來越多的 GAN 相關的論文發表。
1. 基本原理
GAN 的思想其實非常簡單,就是生成器網絡和判別器網絡的彼此博弈。
GAN 主要就是兩個網絡組成,生成器網絡(Generator)和判別器網絡(Discriminator),通過這兩個網絡的互相博弈,讓生成器網絡最終能夠學習到輸入數據的分布,這也就是 GAN 想達到的目的–學習輸入數據的分布。其基本結構如下圖所示,從下圖可以更好理解G 和 D 的功能,分別為:
- D 是判別器,負責對輸入的真實數據和由 G 生成的假數據進行判斷,其輸出是 0 和 1,即它本質上是一個二值分類器,目標就是對輸入為真實數據輸出是 1,對假數據的輸入,輸出是 0;
- G 是生成器,它接收的是一個隨機噪聲,并生成圖像。
在訓練的過程中,G 的目標是盡可能生成足夠真實的數據去迷惑 D,而 D 就是要將 G 生成的圖片都辨別出來,這樣兩者就是互相博弈,最終是要達到一個平衡,也就是納什均衡。
2. 優點
(以下優點和缺點主要來自 Ian Goodfellow 在 Quora 上的回答,以及知乎上的回答)
- GAN 模型只用到了反向傳播,而不需要馬爾科夫鏈
- 訓練時不需要對隱變量做推斷
- 理論上,只要是可微分函數都可以用于構建 D 和 G ,因為能夠與深度神經網絡結合做深度生成式模型
- G 的參數更新不是直接來自數據樣本,而是使用來自 D 的反向傳播
- 相比其他生成模型(VAE、玻爾茲曼機),可以生成更好的生成樣本
- GAN 是一種半監督學習模型,對訓練集不需要太多有標簽的數據;
- 沒有必要遵循任何種類的因子分解去設計模型,所有的生成器和鑒別器都可以正常工作
3. 缺點
- 可解釋性差,生成模型的分布 Pg(G)沒有顯式的表達
- 比較難訓練, D 與 G 之間需要很好的同步,例如 D 更新 k 次而 G 更新一次
- 訓練 GAN 需要達到納什均衡,有時候可以用梯度下降法做到,有時候做不到.我們還沒有找到很好的達到納什均衡的方法,所以訓練 GAN 相比 VAE 或者 PixelRNN 是不穩定的,但我認為在實踐中它還是比訓練玻爾茲曼機穩定的多.
- 它很難去學習生成離散的數據,就像文本
- 相比玻爾茲曼機,GANs 很難根據一個像素值去猜測另外一個像素值,GANs 天生就是做一件事的,那就是一次產生所有像素,你可以用 BiGAN 來修正這個特性,它能讓你像使用玻爾茲曼機一樣去使用 Gibbs 采樣來猜測缺失值
- 訓練不穩定,G 和 D 很難收斂;
- 訓練還會遭遇梯度消失、模式崩潰的問題
- 缺乏比較有效的直接可觀的評估模型生成效果的方法
3.1 為什么訓練會出現梯度消失和模式奔潰
GAN 的本質就是 G 和 D 互相博弈并最終達到一個納什平衡點,但這只是一個理想的情況,正常情況是容易出現一方強大另一方弱小,并且一旦這個關系形成,而沒有及時找到方法平衡,那么就會出現問題了。而梯度消失和模式奔潰其實就是這種情況下的兩個結果,分別對應 D 和 G 是強大的一方的結果。
首先對于梯度消失的情況是D 越好,G 的梯度消失越嚴重,因為 G 的梯度更新來自 D,而在訓練初始階段,G 的輸入是隨機生成的噪聲,肯定不會生成很好的圖片,D 會很容易就判斷出來真假樣本,也就是 D 的訓練幾乎沒有損失,也就沒有有效的梯度信息回傳給 G 讓 G 去優化自己。這樣的現象叫做 gradient vanishing,梯度消失問題。
其次,對于模式奔潰(mode collapse)問題,主要就是 G 比較強,導致 D 不能很好區分出真實圖片和 G 生成的假圖片,而如果此時 G 其實還不能完全生成足夠真實的圖片的時候,但 D 卻分辨不出來,并且給出了正確的評價,那么 G 就會認為這張圖片是正確的,接下來就繼續這么輸出這張或者這些圖片,然后 D 還是給出正確的評價,于是兩者就是這么相互欺騙,這樣 G 其實就只會輸出固定的一些圖片,導致的結果除了生成圖片不夠真實,還有就是多樣性不足的問題。
更詳細的解釋可以參考 令人拍案叫絕的Wasserstein GAN,這篇文章更詳細解釋了原始 GAN 的問題,主要就是出現在 loss 函數上。
3.2 為什么GAN不適合處理文本數據
3.3 為什么GAN中的優化器不常用SGD
對于鞍點,來自百度百科的解釋是:
鞍點(Saddle point)在微分方程中,沿著某一方向是穩定的,另一條方向是不穩定的奇點,叫做鞍點。在泛函中,既不是極大值點也不是極小值點的臨界點,叫做鞍點。在矩陣中,一個數在所在行中是最大值,在所在列中是最小值,則被稱為鞍點。在物理上要廣泛一些,指在一個方向是極大值,另一個方向是極小值的點。
鞍點和局部極小值點、局部極大值點的區別如下圖所示:
4. 訓練的技巧
訓練的技巧主要來自Tips and tricks to make GANs work。
1. 對輸入進行規范化
- 將輸入規范化到 -1 和 1 之間
- G 的輸出層采用Tanh激活函數
2. 采用修正的損失函數
在原始 GAN 論文中,損失函數 G 是要 min(log(1?D))min (log(1-D))min(log(1?D)), 但實際使用的時候是采用 max(logD)max(logD)max(logD),作者給出的原因是前者會導致梯度消失問題。
但實際上,即便是作者提出的這種實際應用的損失函數也是存在問題,即模式奔潰的問題,在接下來提出的 GAN 相關的論文中,就有不少論文是針對這個問題進行改進的,如 WGAN 模型就提出一種新的損失函數。
3. 從球體上采樣噪聲
- 不要采用均勻分布來采樣
- 從高斯分布中采樣得到隨機噪聲
- 當進行插值操作的時候,從大圓進行該操作,而不要直接從點 A 到 點 B 直線操作,如下圖所示
- 更多細節可以參考 Tom White’s 的論文 Sampling Generative Networks 以及代碼 https://github.com/dribnet/plat
4. BatchNorm
- 采用 mini-batch BatchNorm,要保證每個 mini-batch 都是同樣的真實圖片或者是生成圖片
- 不采用 BatchNorm 的時候,可以采用 instance normalization(對每個樣本的規范化操作)
- 可以使用虛擬批量歸一化(virtural batch normalization):開始訓練之前預定義一個 batch R,對每一個新的 batch X,都使用 R+X 的級聯來計算歸一化參數
5. 避免稀疏的梯度:Relus、MaxPool
- 稀疏梯度會影響 GAN 的穩定性
- 在 G 和 D 中采用 LeakyReLU 代替 Relu 激活函數
- 對于下采樣操作,可以采用平均池化(Average Pooling) 和 Conv2d+stride 的替代方案
- 對于上采樣操作,可以使用 PixelShuffle(https://arxiv.org/abs/1609.05158), ConvTranspose2d + stride
6. 標簽的使用
- 標簽平滑。也就是如果有兩個目標標簽,假設真實圖片標簽是 1,生成圖片標簽是 0,那么對每個輸入例子,如果是真實圖片,采用 0.7 到 1.2 之間的一個隨機數字來作為標簽,而不是 1;一般是采用單邊標簽平滑
- 在訓練 D 的時候,偶爾翻轉標簽
- 有標簽數據就盡量使用標簽
7. 使用 Adam 優化器
8. 盡早追蹤失敗的原因
- D 的 loss 變成 0,那么這就是訓練失敗了
- 檢查規范的梯度:如果超過 100,那出問題了
- 如果訓練正常,那么 D loss 有低方差并且隨著時間降低
- 如果 g loss 穩定下降,那么它是用糟糕的生成樣本欺騙了 D
9. 不要通過統計學來平衡 loss
10. 給輸入添加噪聲
- 給 D 的輸入添加人為的噪聲
- http://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
- https://openreview.net/forum?id=Hk4_qw5xe
- 給 G 的每層都添加高斯噪聲
11. 對于 Conditional GANs 的離散變量
- 使用一個 Embedding 層
- 對輸入圖片添加一個額外的通道
- 保持 embedding 低維并通過上采樣操作來匹配圖像的通道大小
12 在 G 的訓練和測試階段使用 Dropouts
- 以 dropout 的形式提供噪聲(50%的概率)
- 訓練和測試階段,在 G 的幾層使用
- https://arxiv.org/pdf/1611.07004v1.pdf
參考文章:
- Goodfellow et al., “Generative Adversarial Networks”. ICLR 2014.
- GAN系列學習(1)——前生今世
- 干貨 | 深入淺出 GAN·原理篇文字版(完整)
- 令人拍案叫絕的Wasserstein GAN
- 生成對抗網絡(GAN)相比傳統訓練方法有什么優勢?
- the-gan-zoo
- What-is-the-advantage-of-generative-adversarial-networks-compared-with-other-generative-models
- What-are-the-pros-and-cons-of-using-generative-adversarial-networks-a-type-of-neural-network-Could-they-be-applied-to-things-like-audio-waveform-via-RNN-Why-or-why-not
- Tips and tricks to make GANs work
注:配圖來自網絡和參考文章
以上就是本文的主要內容和總結,可以留言給出你對本文的建議和看法。
同時也歡迎關注我的微信公眾號–機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!
總結
以上是生活随笔為你收集整理的[GAN学习系列] 初识GAN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 后端:Layui实现文件上传功能
- 下一篇: jersey创建restful服务及调用