重读经典:《Generative Adversarial Nets》
GAN論文逐段精讀【論文精讀】
這是李沐博士論文精讀的第五篇論文,這次精讀的論文是 GAN。目前谷歌學術顯示其被引用數已經達到了37000+。GAN 應該是機器學習過去五年上頭條次數最多的工作,例如抖音里面生成人物卡通頭像,人臉互換以及自動駕駛中通過傳感器采集的數據生成逼真的圖像數據,用于仿真測試等。這里李沐博士講解的論文是 NeurIPS 版,與 arXiv 版稍有不同。
GAN 論文鏈接:https://proceedings.neurips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf
1. 標題、作者、摘要
首先是論文標題,GAN 就取自于論文標題首字母,論文標題中文意思是:生成式對抗網絡。機器學習里面有兩大類模型:一種是分辨模型,例如 AlexNet、ResNet 對數據進行分類或預測一個實數值、另一種就是生成模型,用于生成數據本身。Adversarial 是對抗的意思,第一次讀的時候可能不知道什么意思,先放在這里,接著往下讀。最后是 Nets,網絡的意思,不過建議大家還是寫成 Networks 比較規范一些。
下面是論文作者,一作大家很熟悉了,他的另一個代表作就是深度學習經典書籍(花書):《深度學習》,通信作者是深度學習三巨頭之一,2018年圖靈獎的獲得者。
這里有一個小八卦,當時一作在給論文取標題時,有人說 GAN 這個詞在中文里寫作干,和英語里的 fxxk 意思很接近,但是意義上豐富多了,一作就說這個好,就用它了。
下面是論文摘要,摘要總共七句話。
- 前三句話介紹我們提出了一個新的 framework, 通過對抗過程估計生成模型;我們同時會訓練兩個模型,一個是生成模型 GGG,生成模型用來捕獲數據的分布,另一個模型是辨別模型 DDD,辨別模型用來判斷樣本是來自于訓練數據還是生成模型生成的。生成模型 GGG 的訓練過程是使辨別模型犯錯概率最大化實現的,當辨別模型犯錯概率越大,則生成模型生成的數據越接近于真實數據。整個framework類似于博弈論里的二人對抗游戲。
- 第四句話是說,在任意函數空間里,存在唯一解,GGG 能找出訓練數據的真實分布,而 DDD 的預測概率為 12\frac{1}{2}21?,此時辨別模型已經分辨不出樣本的來源。
- 最后就是說生成模型和辨別模型可以通過反向傳播進行訓練,實驗也顯示了提出的框架潛能。
2. 導言、相關工作
下面是 Introduction 部分,總共3段。
- 第一段說深度學習在判別模型取得了很大的成功,但是在生成模型進展還很緩慢,主要原因是在最大似然估計時會遇到很多棘手的近似概率計算,因此作者提出一個新的生成模型來解決這些問題。
- 第二段作者舉了一個例子來解釋對抗網絡。生成模型好比是一個造假者,而判別模型好比是警察,警察需要能區分真幣和假幣,而造假者需要不斷改進技術使警察不能區分真幣和假幣。
- 第三段說生成模型可以通過多層感知機來實現,輸入為一些隨機噪聲,可以通過反向傳播來訓練。
然后是相關工作部分,這里有件有趣的事。當時GAN作者在投稿時,Jürgen Schmidhuber 恰好是論文審稿者,Jürgen Schmidhuber 就質問:“你這篇論文和我的 PM 論文很相似,只是方向相反了,應該叫 Inverse PM 才對”。然后Ian就在郵件中回復了,但是兩人還在爭論。
一直到NIPS2016大會,Ian 的 GAN Tutorial上,發生了尷尬的一幕。Jürgen Schmidhuber 站起來提問后,先講自己在1992年提出了一個叫做 Predictability Minimization 的模型,它如何如何,一個網絡干嘛另一個網絡干嘛,接著話鋒一轉,直問臺上的Ian:“你覺得我這個 PM 模型跟你的 GAN 有沒有什么相似之處啊?” 似乎只是一個很正常的問題,可是 Ian 聽完后反應卻很激烈。Ian 表示:“Schmidhuber 已經不是第一次問我這個問題了,之前我和他就已經通過郵件私下交鋒了幾回,所以現在的情況純粹就是要來跟我公開當面對質,順便浪費現場幾百號人聽tutorial 的時間。然后你問我 PM 模型和 GAN 模型有什么相似之處,我早就公開回應過你了,不在別的地方,就在我當年的論文中,而且后來的郵件也已經把我的意思說得很清楚了,還有什么可問的呢?”
關于Jürgen Schmidhuber 和 Ian之間爭論的更多趣事可以看這篇文章:從PM到GAN——LSTM之父Schmidhuber橫跨22年的怨念。
3. 模型、理論
下面開始介紹 Adversarial nets。為了學習生成器在數據 x\boldsymbol{x}x 上的分布 pgp_gpg?,我們定義輸入噪聲變量 pz(z)p_{\boldsymbol{z}}({\boldsymbol{z}})pz?(z),數據空間的映射用 G(z;θg)G(\boldsymbol{z};\theta_g)G(z;θg?) 表示,其中 GGG 是一個可微分函數(多層感知機),其參數為 θg\theta_gθg?。我們再定義第二個多層感知機 D(x;θd)D(\boldsymbol{x};\theta_d)D(x;θd?),其輸出為標量。D(x)D(\boldsymbol{x})D(x) 表示數據 x\boldsymbol{x}x 來自真實數據的概率。
下面是訓練策略,我們同時訓練生成模型 GGG 和判別模型 DDD。對于判別模型 DDD,我們通過最大化將正確標簽分配給訓練樣本和生成器生成樣本的概率來訓練;對于生成模型 GGG,我們通過最小化 log?(1?D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1?D(G(z))) 來訓練,總結為:
- D(x)D(\boldsymbol{x})D(x) 概率越大,判別器訓練越好,log?D(x)\log D(\boldsymbol{x})logD(x) 越大;
- D(G(z))D(G(\boldsymbol{z}))D(G(z)) 概率越小,判別器訓練越好,log?(1?D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1?D(G(z))) 越大;
- D(G(z))D(G(\boldsymbol{z}))D(G(z)) 概率越大,生成器訓練越好,log?(1?D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1?D(G(z))) 越小;
下圖是對抗網絡訓練的直觀示意圖,黑色曲線是真實樣本,綠色曲線為生成樣本,藍色曲線為判別概率。可以看到在 (a) 階段,真實樣本和生成樣本分布不一致,此時判別器能夠正確區分真實樣本和生成樣本。到 (d) 階段,真實樣本和生成樣本分布幾乎一致,此時判別器很難再區分二者,此時判別器輸出概率為 12\frac{1}{2}21?。
算法1是整個對抗網絡的正式描述,對于判別器,我們通過梯度上升來訓練;對于生成器,我們通過梯度下降來訓練。
在實際訓練時,公式(1)往往不能提供足夠的梯度讓生成器去學習。因為在學習的早期階段,生成器 GGG 性能很差,判別器 DDD 有著很高的置信度判別數據來源。在這種情況,log?(1?D(G(z)))\log (1-D(G(\boldsymbol{z})))log(1?D(G(z))) 存在飽和現象。因此在這個時候,我們通過最大化 log?D(G(z))\log D(G(\boldsymbol{z}))logD(G(z)) 來訓練生成器 GGG。
下面是 Theoretical Results,對于任意給定的生成器 GGG,則最優的判別器 DDD 為:
DG?(x)=pdata?(x)pdata?(x)+pg(x)D_{G}^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})} DG??(x)=pdata??(x)+pg?(x)pdata??(x)?
下面是證明過程,對于給定的生成器 GGG,判別器 DDD 通過最大化期望 V(G,D)V(G,D)V(G,D) 來訓練, V(G,D)V(G,D)V(G,D) 為:
V(G,D)=∫xpdata?(x)log?(D(x))dx+∫zpz(z)log?(1?D(g(z)))dz=∫xpdata?(x)log?(D(x))+pg(x)log?(1?D(x))dx\begin{aligned} V(G, D) &=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x})) d x+\int_{\boldsymbol{z}} p_{\boldsymbol{z}}(\boldsymbol{z}) \log (1-D(g(\boldsymbol{z}))) d z \\ &=\int_{\boldsymbol{x}} p_{\text {data }}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x})) d x \end{aligned} V(G,D)?=∫x?pdata??(x)log(D(x))dx+∫z?pz?(z)log(1?D(g(z)))dz=∫x?pdata??(x)log(D(x))+pg?(x)log(1?D(x))dx?
已知 (a,b)∈R2(a, b) \in \mathbb{R}^{2}(a,b)∈R2,函數 y→alog?(y)+blog?(1?y)y \rightarrow a \log (y)+b \log (1-y)y→alog(y)+blog(1?y) 在 aa+b\frac{a}{a+b}a+ba? 處取得最大值。
根據上面的證明,在最優判別器處,則有最大期望值 ?log?4-\log4?log4。
最后簡單總結下,雖然在本文中,作者做的實驗現在來看比較簡單,但是整個工作是一個開創性的工作,GAN 屬于無監督學習研究,而且作者是使用有監督學習的損失函數去訓練無監督學習;而且本文的寫作也是教科書級別的寫作,作者的寫作是很明確的,讀者只看這一篇文章就能對GAN有足夠的了解,不需要再去看其它更多的文獻。
總結
以上是生活随笔為你收集整理的重读经典:《Generative Adversarial Nets》的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pcclient.exe是什么进程 pc
- 下一篇: 交行信用卡取现手续费和利息 这样算清清楚