CVPR 2021 | 如何让GAN的训练更加高效优雅
導讀
近年來,生成對抗技術在諸多圖像任務中得到運用,包括圖像編輯和生成、風格遷移和轉化、圖文描述生成、少樣本數據增強、圖像攻防對抗以及 AI 字體設計等。圖像生成對抗雖然取得不少成功運用案例,但其訓練效率對規模化日常迭代是個挑戰。為此,阿里媽媽搜索廣告團隊?聯合?浙江大學宋明黎教授的視覺智能與模式分析團隊?對此項工作開展了探索性研究,并提出了一種單階段生成對抗訓練方法(OSGAN, Training Generative Adversarial Networks in One Stage)來提升傳統 GAN 任務的訓練效率。實測該方法比傳統兩階段訓練方法實現了1.5倍的訓練加速。該項工作論文已被 CVPR 2021錄用,并已開源,歡迎交流討論。
論文下載:https://arxiv.org/abs/2103.00430?
開源項目:https://github.com/zju-vipa/OSGAN
背景
在諸多神經網絡中,生成對抗網絡(Generative Adversarial Network,GAN)的訓練方式和其他神經網絡訓練存在較大的區別:傳統 CNN 任務中,網絡各部分都按照最小化目標函數的方向進行優化,而 GAN 中生成器(generator)和判別器(discriminator)則是朝著相反的方向進行優化,以形成對抗。為此,當時Ian J. Goodfellow [1]采用了對生成器和判別器進行交替優化的方案,歸納為兩階段訓練方式(Two-Stage GAN,如下圖1所示)。顯然,這種兩階段訓練方式引入了不少重復的計算量,使得GAN的訓練效率通常低于其他神經網絡。
圖1針對上述 GAN 訓練效率低的問題,以下簡短介紹下我們最近發表在 CVPR 2021上的工作:Training Generative Adversarial Networks in One Stage。文章對該問題進行了深入的研究探索,即如何在一次訓練迭代中,同時完成對生成器和判別器的更新以消除Two-Stage GAN訓練中存在的冗余計算。文章中同樣將現有的 GAN 分為對稱GAN(Symmetric GAN)和 非對稱GAN(Asymmetric GAN)兩大類[2],并著重研究了如何對更加復雜的非對稱GAN進行了單階段訓練的問題。最后,對 One-Stage GAN 相對于 Two-Stage GAN 的加速比進行了對應分析。
方案
我們先簡單回顧下目前主流的兩類生成對抗訓練方法:
對稱GAN vs 非對稱GAN
GAN通過引入判別器網絡和生成器網絡之間最大化最小化的博弈過程[1],使得生成網絡實現了真實性樣本合成。GAN所采用的目標函數如下:
為了后續討論的方便,我們將上述目標函數拆解成分別針對判別器 和生成器的損失函數,其形式如下:
其中,和包含了關于 的相同對抗損失項:。因此,將這種生成對抗網絡稱為?對稱型GAN?。為了緩解訓練過程中生成器網絡梯度消失的問題,有學者又提出了非飽和對抗損失函數[3]:
其中,生成器網絡和判別器網絡關于的對抗損失項并不一致: vs 。將上述生成對抗網絡稱為?非對稱型GAN?。
如上述公式所述,GAN目標函數中的對抗項通常是關于的。因此,為了分析的方便性,將一般化的分成兩個部分:關于真實樣本的損失項和關于假樣本的損失項,其中。最終生成對抗網絡一般化的目標函數可以表示為:
對于對稱型生成對抗網絡來說,其對抗損失項滿足:;而對于非對稱型生成對抗網絡來說,其對抗損失項滿足:。
對稱型GAN的單階段訓練
對于對稱型GAN,和包含相同的關于假樣本的損失項:。其關于的梯度可以分別表示為和。通過在上乘以來得到。因此,可以利用在訓練判別器期間得到的進一步計算得到關于生成器參數的梯度,從而在更新實現的同時,訓練生成器。綜上所述,上述方法可以將對稱型 GAN 的兩階段訓練過程簡化成單階段過程。
非對稱型GAN的單階段訓練
對于非對稱型GAN,由于,梯度無法直接像對稱型GAN一樣從中獲取到。一種直接的思路是將 和 整合成一個損失函數,例如,從而我們可以從 中獲取得到 。因為GAN的對抗特性,和的符號通常是相反的。因此,我們采用以下損失函數整合方式:
而不采用 ,以避免和之間因為符號相反而產生的梯度抵消。設,以合并關于假樣本的損失項。然而,這種方式會產生另外一個問題:如何從混合的梯度中恢復出。
為了解決上述問題,我們對判別器網絡的反向傳播進行了調研,發現了主流神經網絡模塊中有一個很有意思的反向傳播性質。除了批歸一化模塊,其他模塊對應的損失函數 關于輸入的梯度和關于輸出的梯度之間的關系可以表示為:
其中 和 是由對應神經模塊或者其輸出決定的矩陣; 是一個滿足如下關系的函數:
上述梯度滿足公式的神經網絡模塊主要包括卷積模塊、全連接模塊以及非線性激活函數、以及池化模塊等。需要說明的是,雖然非線性激活函數和池化模塊是非線性操作,但是其關于輸入的梯度和輸出的梯度仍然滿足上述公式。
根據上述梯度公式,我們可以得到判別器網絡中 關于假樣本的梯度和 關于的梯度之間的關系:
其中表示判別器網絡的層數;是關于 的樣本比例標量;表示判別器 第層的特征。
關于樣本 的 對于判別器不同的網絡層是一個常數。同時對于每一個樣本都有一個對應的 。也就是說,一般不同樣本 有不同的 。一方面,我們只需要計算最后一層的 ,即 ,就可以得到所有網絡層的 值。這個方式只需要計算兩個標量之間的比值,計算過程十分簡單高效。另外一方面,的值根據樣本的不同而發生變化,因此,對于每一個樣本都要重新計算一次。由于是兩個標量:和 的比值,因此其計算代價在整個網絡訓練中可以忽略不計。
結合 和上述公式,我們可以按比例從混合的梯度:分解得到和,具體如下:
也就是說,我們可以通過對進行尺度縮放得到得到和。為了方便計算, 我們將上面這種尺度縮放操作應用到了損失函數上,以實現和上述梯度分解公式相同的效果。因此,我們可以得到判別器網絡和的實例損失函數:
其中和包含相同的損失項。通過這種方式,非對稱型GAN可以轉化為對稱型GAN。因此,可以將對稱型GAN中采用的單階段訓練策略應用到非對稱型GAN中。
實驗分析
我們進一步分析了單階段生成對抗網絡和兩階段生成對抗網絡的效率。主要從三個角度對這個問題進行了分析:1)在一個數據批訓練中,真實樣本的耗時和生成樣本的耗時;2)前向推理的耗時和反向傳播的耗時;3)關于網絡參數梯度計算的耗時和反向傳播的耗時。
圖2經過如上圖2所示的分析發現,普通 GAN 訓練的兩階段耗時分別為:通過和兩階段生成對抗網絡相同的方式,本文計算得到單階段生成對抗網絡的總耗時為:
最終,我們得到了在最壞的情況下單階段生成對抗網絡相對于兩階段生成對抗網絡的加速比:
圖3如上圖3所示,在所有實驗的效果達到穩定情況下,單階段對稱型 DCGAN 比兩階段快接近1.7倍,單階段非對稱 DCGAN 也比兩階段快1.6倍,更多性能數據參考文章 Efficiency Analysis 章節。
總結與展望
針對生成對抗技術在實際任務中訓練周期長的問題,我們提出了一種單階段的訓練方法 OSGAN,該方法在相同效果下相對兩階段方法性能提升1.5倍。同時,我們運用 OSGAN 對少量標注樣本進行數據增廣,在拍立淘廣告場景中,對少樣本有監督 CNN 任務帶來效果提升。未來我們將進一步探索更高效的 GAN 訓練和運用方法,期待可以在更多領域和落地中得到拓展和應用。
參考文獻
[1] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems (NeurIPS), pages 2672–2680, 2014.
[2] Li Liu, Wanli Ouyang, Xiaogang Wang, Paul Fieguth, Jie Chen, Xinwang Liu, and Matti Pietika ?inen. Deep learning for generic object detection: A survey. International Journal of Computer Vision (IJCV), 128(2):261–318, 2020.
[3] Martin Arjovsky and Le ?on Bottou. Towards principled methods for training generative adversarial networks. In International Conference on Learning Representations (ICLR), 2017.
END
???關于我們
阿里媽媽多模態搜索廣告算法團隊負責多模態搜索場景(拍立淘和找相似等)的商業化變現技術,歡迎對“計算機視覺/多模態自監督學習/搜索推薦廣告”感興趣的同學加入我們。投遞簡歷:alimama_tech@service.alibaba.com?,或點擊下方 [閱讀原文] 了解崗位詳情~
瘋狂暗示↓↓↓↓↓↓↓
總結
以上是生活随笔為你收集整理的CVPR 2021 | 如何让GAN的训练更加高效优雅的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 阿里妈妈应用系统大规模异步交互治理方案
- 下一篇: 【阿里妈妈数据科学系列】第一篇:认识在线