深度学习之生成对抗网络(7)WGAN原理
深度學習之生成對抗網絡(7)WGAN原理
- 1. JS散度的缺陷
- 2. EM距離
- 3. WGAN-GP
?WGAN算法從理論層面分析了GAN訓練不穩定的原因,并提出了有效的解決方法。那么是什么原因導致了GAN訓練如此不穩定呢?WGAN提出是因為JS散度在不重疊的分布 ppp和 qqq上的梯度曲面是恒定為0的。如下圖所示。當分布p和q不重疊時,JS散度的梯度值始終為0,從而導致此時GAN的訓練出現梯度彌散現象,參數長時間得不到更新,網絡無法收斂。
圖1. JS散度出現梯度彌散現象
?接下來我們將詳細闡述JS散度的缺陷以及怎么解決此缺陷。
1. JS散度的缺陷
為了避免過多的理論推導,我們這里通過一個簡單的分布實例來解釋JS散度的缺陷。
考慮完全不重疊(θ≠0θ≠0θ?=0)的兩個分布ppp和qqq,其中ppp為:
?(x,y)∈p,x=0,y~U(0,1)?(x,y)∈p,x=0,y\sim\text{U}(0,1)?(x,y)∈p,x=0,y~U(0,1)
分布qqq為:
?(x,y)∈q,x=θ,y~U(0,1)?(x,y)∈q,x=θ,y\sim\text{U}(0,1)?(x,y)∈q,x=θ,y~U(0,1)
其中θ∈Rθ∈Rθ∈R,當θ=0θ=0θ=0時,分布ppp和qqq重疊,兩者相等;當θ≠0θ≠0θ?=0時,分布ppp和qqq不重疊。
?我們來分析上述分布ppp和qqq之間的JS散度隨θθθ的變化情況。根據KL散度與JS散度的定義,計算θ=0θ=0θ=0時的JS散度DJS(p∣∣q)D_{JS} (p||q)DJS?(p∣∣q):
DKL(p∣∣q)=∑x=0,y~U(0,1)1?log?10=+∞D_{KL} (p||q)=∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}?\frac{1}{0}=+∞DKL?(p∣∣q)=x=0,y~U(0,1)∑?1?log?01?=+∞
DKL(q∣∣p)=∑x=θ,y~U(0,1)1?log?10=+∞D_{KL} (q||p)=∑_{x=θ,y\sim\text{U}(0,1)}1\cdot\text{log}?\frac{1}{0}=+∞DKL?(q∣∣p)=x=θ,y~U(0,1)∑?1?log?01?=+∞
DJS(p∣∣q)=12(∑x=0,y~U(0,1)1?log11/2+∑x=0,y~U(0,1)1?log11/2)=log?2D_{JS} (p||q)=\frac{1}{2} \bigg(∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}+∑_{x=0,y\sim\text{U}(0,1)}1\cdot\text{log}\frac{1}{1/2}\bigg)=\text{log}?2DJS?(p∣∣q)=21?(x=0,y~U(0,1)∑?1?log1/21?+x=0,y~U(0,1)∑?1?log1/21?)=log?2
?當θ=0θ=0θ=0時,兩個分布完全重疊,此時的JS散度和KL散度都取得最小值,即0:
DKL(p∣∣q)=DKL(q∣∣p)=DJS(p∣∣q)=0D_{KL} (p||q)=D_{KL} (q||p)=D_{JS} (p||q)=0DKL?(p∣∣q)=DKL?(q∣∣p)=DJS?(p∣∣q)=0
從上面的推導,我們可以得到DJS(p∣∣q)D_{JS} (p||q)DJS?(p∣∣q)隨θθθ的變化趨勢:
DJS(p∣∣q)={log?2θ≠00θ=0D_{JS} (p||q) = \begin{cases} \text{log?}2 &\text{} θ≠0 \\ 0 &\text{} θ=0 \end{cases}DJS?(p∣∣q)={log?20?θ?=0θ=0?
也就是說,當兩個分布完全不重疊時,無論發布之間的距離遠近,JS散度為恒定值log?2\text{log}?2log?2,此時JS散度將無法產生有效的梯度信息;當兩個分布出現重疊時,JS散度采會平滑變動,產生有效梯度信息;當完全重合后,JS散度取得最小值0.如下圖所示,紅色的曲線分割兩個正態分布,由于兩個分布沒有重疊,生成樣本位置處的梯度值始終為0,無法更新生成網絡的參數,從而出現網絡訓練困難的現象。
?因此,JS散度在分布ppp和qqq不重疊時是無法平滑地衡量分布之間的距離,從而導致此位置上無法產生有效梯度信息,出現GAN訓練不穩定的情況。要解決此問題,需要使用一種更好的分布距離衡量標準,使得它即使在分布ppp和qqq不重疊時,也能平滑反映分布之間的真實距離變化。
2. EM距離
?WGAN論文發現了JS散度導致GAN訓練不穩定的問題,并引入了一種新的分布距離度量方法:Wasserstein距離,也叫推土機距離(Earth-Mover Distance,簡稱EM距離),它表示了從一個分布變換到另一個分布的最小代價,定義為:
W(p,q)=infγ~∏(p,q)E(x,y)~γ[∥x?y∥]W(p,q)=\underset{γ\sim∏(p,q)}{\text{inf}}\mathbb E_{(x,y)\simγ} [\|x-y\|]W(p,q)=γ~∏(p,q)inf?E(x,y)~γ?[∥x?y∥]
其中∏(p,q)∏(p,q)∏(p,q)是分布ppp和qqq組合起來的所有可能的聯合分布的集合,對于每個可能的聯合分布γ~∏(p,q)γ\sim∏(p,q)γ~∏(p,q),計算距離∥x?y∥\|x-y\|∥x?y∥的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],其中(x,y)(x,y)(x,y)采樣自聯合分布γγγ。不同的聯合分布γγγ由不同的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],這些期望中的下確界即定義為分布ppp和qqq的Wasserstein距離。其中inf?{?}\text{inf}?\{\cdot\}inf?{?}表示集合的下確界,例如{x∣1<x<3,x∈R}\{x|1<x<3,x∈R\}{x∣1<x<3,x∈R}的下確界為1。
?繼續考慮圖2中的例子,我們直接給出分布ppp和qqq之間的EM距離的表達式:
W(p,q)=∣θ∣W(p,q)=|θ|W(p,q)=∣θ∣
繪制出JS散度和EM距離的曲線,如下圖所示,可以看到,JS散度在θ=0θ=0θ=0處不連續,其他位置導數均為0,而EM距離總能夠產生有效的導數信息,因此EM距離相對于JS散度更適合直到GAN網絡的訓練。
3. WGAN-GP
?考慮到幾乎不可能遍歷所有的聯合分布γγγ去計算距離∥x?y∥\|x-y\|∥x?y∥的期望E(x,y)~γ[∥x?y∥]\mathbb E_{(x,y)\simγ} [\|x-y\|]E(x,y)~γ?[∥x?y∥],因此直接計算生成網絡分布pgp_gpg?與真實數據數據分布prp_rpr?的距離W(pr,pg)W(p_r,p_g )W(pr?,pg?)距離是不現實的,WGAN作者基于Kantorchovich-Rubin對偶性將直接求W(pr,pg)W(p_r,p_g )W(pr?,pg?)轉換為求:
W(pr,pg)=1Ksup∥f∥L≤KEx~pr[f(x)]?Ex~pg[f(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|f\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [f(x)]-\mathbb E_{x\sim p_g} [f(x)]W(pr?,pg?)=K1?∥f∥L?≤Ksup?Ex~pr??[f(x)]?Ex~pg??[f(x)]
其中sup?{?}\text{sup}?\{\cdot\}sup?{?}表示集合的上確界,∥f∥L≤K\|f\|_L≤K∥f∥L?≤K表示函數f:R→Rf:R→Rf:R→R滿足K階-Lipschitz連續性,即滿足
∣f(x1)?f(x2)∣≤K?∣x1?x2∣|f(x_1 )-f(x_2)|≤K\cdot|x_1-x_2 |∣f(x1?)?f(x2?)∣≤K?∣x1??x2?∣
?于是,我們使用判別網絡Dθ(x)D_θ (\boldsymbol x)Dθ?(x)參數化f(x)f(\boldsymbol x)f(x)函數,在DθD_θDθ?滿足1階-Lipschitz約束條件下,即K=1K=1K=1,此時:
W(pr,pg)=1Ksup∥Dθ∥L≤KEx~pr[Dθ(x)]?Ex~pg[Dθ(x)]W(p_r,p_g )=\frac{1}{K} \underset{\|D_θ\|_L≤K}{\text{sup}} \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]W(pr?,pg?)=K1?∥Dθ?∥L?≤Ksup?Ex~pr??[Dθ?(x)]?Ex~pg??[Dθ?(x)]
因此求解W(pr,pg)W(p_r,p_g )W(pr?,pg?)的問題可以轉化為:
max?θEx~pr[Dθ(x)]?Ex~pg[Dθ(x)]\underset{θ}{\text{max}?}\ \mathbb E_{x\sim p_r} [D_θ (\boldsymbol x)]-\mathbb E_{x\sim p_g} [D_θ (\boldsymbol x)]θmax???Ex~pr??[Dθ?(x)]?Ex~pg??[Dθ?(x)]
這就是判別器D的優化目標。判別網絡函數D_θ (x)需要滿足1階-Lipschitz約束:
?x^D(x^)≤1?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})≤1?x^?D(x^)≤1
?在WGAN-GP論文中,作者提出采用增加梯度懲罰項(Gradient Penalty)方法來迫使判別網絡滿足1階-Lipschitz函數約束,同時作者發現將梯度值約束在1周圍時工程效果更好,因此梯度懲罰項定義為:
GP?Ex^~Px^[(∥?x^D(x^)∥2?1)2]GP?\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]GP?Ex^~Px^??[(∥?x^?D(x^)∥2??1)2]
因此WGAN的判別器D的訓練目標為:
maxθL(G,D)=Exr~pr[D(xr)]?Exf~pg[D(xf)]?EM距離?λEx^~Px^[(∥?x^D(x^)∥2?1)2]?GP懲罰項\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距離}-\underbrace{λ\mathbb E_{\hat{\boldsymbol x}\sim P_{\hat{\boldsymbol x}}} [(\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2-1)^2]}_{GP懲罰項}θmax?L(G,D)=EM距離Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]???GP懲罰項λEx^~Px^??[(∥?x^?D(x^)∥2??1)2]??
其中x^\hat{\boldsymbol x}x^來自于xr\boldsymbol x_rxr?與xf\boldsymbol x_fxf?的線性差值:
x^=txr+(1?t)xf,t∈[0,1]\hat{\boldsymbol x}=t\boldsymbol x_r+(1-t) \boldsymbol x_f,t∈[0,1]x^=txr?+(1?t)xf?,t∈[0,1]
判別器D的優化目標是最小化上述的誤差L(G,D)\mathcal L(G,D)L(G,D),即迫使生成器G的分布pgp_gpg?與真實分布prp_rpr?之間的EM距離Exr~pr[D(xr)]?Exf~pg[D(xf)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]項盡可能大,∥?x^D(x^)∥2\|?_{\hat{\boldsymbol x}} D(\hat{\boldsymbol x})\|_2∥?x^?D(x^)∥2?逼近于1。
?WGAN的生成器G的訓練目標為:
maxθL(G,D)=Exr~pr[D(xr)]?Exf~pg[D(xf)]?EM距離\underset{θ}{\text{max}} \mathcal L(G,D)=\underbrace{\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]}_{EM距離}θmax?L(G,D)=EM距離Exr?~pr??[D(xr?)]?Exf?~pg??[D(xf?)]??
即使得生成器的分布pgp_gpg?與真實分布prp_rpr?之間的EM距離越小越好??紤]到Exr~pr[D(xr)]\mathbb E_{\boldsymbol x_r\sim p_r } [D(\boldsymbol x_r )]Exr?~pr??[D(xr?)]一項與生成器無關,因此生成器的訓練目標簡寫為:
maxθL(G,D)=?Exf~pg[D(xf)]=?Ez~pz(?)[D(G(z))]\begin{aligned}\underset{θ}{\text{max}} \mathcal L(G,D)&=-E_{\boldsymbol x_f\sim p_g} [D(\boldsymbol x_f )]\\ &=-E_{\boldsymbol z\sim p_\boldsymbol z (\cdot)} [D(G(\boldsymbol z))]\end{aligned}θmax?L(G,D)?=?Exf?~pg??[D(xf?)]=?Ez~pz?(?)?[D(G(z))]?
?從現實來看,判別網絡D的輸出不需要添加Sigmoid激活函數,這是因為原始版本的判別器的功能是作為二分類網絡,添加Sigmoid函數獲得類別的概率;而WGAN中判別器作為EM距離的度量網絡,其目標是衡量生成網絡的分布pgp_gpg?和真實分布prp_rpr?之間的EM距離,屬于實數空間,因此不需要添加Sigmoid激活函數。在誤差函數計算時,WGAN也沒有log\text{log}log函數存在。在訓練WGAN時,WGAN作者推薦使用RMSProp或SGD等不帶動量的優化器。
?WGAN從理論層面發現了原始GAN容易出現訓練不穩定的原因,并給出了一種新的距離度量標準和工程實現解決方案,取得了較好的效果。WGAN還在一定程度上緩解了模式崩塌的問題,使用WGAN的模型不容易出現模式崩塌的現象。需要注意的是,WGAN一般并不能提升模型的生成效果,僅僅是保證了模型訓練的穩定性。當然,保證模型能夠穩定地訓練也是取得良好效果的前提。如圖5所示,原始版本的DCGAN在不使用BN層等設定時出現了訓練不穩定的現象,在同樣設定下,使用WGAN來訓練判別器可以避免此現象,如圖6所示。
圖6. 不帶BN層的WGAN生成效果 創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎
總結
以上是生活随笔為你收集整理的深度学习之生成对抗网络(7)WGAN原理的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 使用linux的dhclient命令动态
- 下一篇: 深度学习之生成对抗网络(8)WGAN-G