python开源库生成式对抗网络_Python:使用Tensorflow开发一维生成对抗网络
生成式對抗網(wǎng)絡(luò)是一種用于訓(xùn)練生成器模型的深度學(xué)習(xí)體系結(jié)構(gòu)。GAN由兩個模型組成,一個稱為生成器(Generator),另一個稱為判別器(Discriminator)。顧名思義,生成器生成新樣本,判別器負(fù)責(zé)對生成的樣本進(jìn)行真?zhèn)畏诸悺?/p>
GAN實(shí)際如何運(yùn)作的?
判別器模型的性能用于更新生成器和判別器本身的網(wǎng)絡(luò)權(quán)重。生成器實(shí)際上從未看到過數(shù)據(jù),而是根據(jù)判別器的性能不斷地進(jìn)行調(diào)整,更具體地說,是根據(jù)從判別器傳回來的誤差梯度進(jìn)行調(diào)整。生成器逐漸學(xué)會通過產(chǎn)生與真實(shí)樣本完全相同的樣本來欺騙判別器。
在這篇文章中,我們將選擇一個簡單的一維函數(shù)來直觀地理解GAN。本文分為5個部分:
選擇一個一維函數(shù)實(shí)現(xiàn)判別器模型實(shí)現(xiàn)生成器模型訓(xùn)練GAN模型性能評估1.一維函數(shù)
我們需要選擇一個一維函數(shù)來制作模型。一維函數(shù)的形式為
y = f(x),其中x是輸入,y是對應(yīng)的輸出
為簡單起見,我將使用函數(shù)y = x。您可以自由選擇任何函數(shù)。我們將保持輸入在-0.5和+0.5之間。下面給出了一個計算輸入的簡單函數(shù) :
該函數(shù)簡單地接受N個隨機(jī)值,并將每個值減去0.5,以便將輸入范圍保持在-0.5和+0.5之間。當(dāng)為real時y=1,當(dāng)為fake時y=0。
2.判別器模型
判別器只是一個簡單的分類模型,它可以預(yù)測樣本是real還是fake。判別器將兩個實(shí)數(shù)值的樣本作為輸入,并輸出樣本是real還是fake。我們處理的問題非常簡單,所以我們不需要非常復(fù)雜的神經(jīng)網(wǎng)絡(luò),我們將只采用一個隱藏層,其中有25個節(jié)點(diǎn)。您可以自由地試驗(yàn)節(jié)點(diǎn)數(shù)或?qū)訑?shù),以提高生成器的準(zhǔn)確性。我們將對隱藏層使用ReLu激活,對輸出層使用sigmoid激活。Python實(shí)現(xiàn)如下:
3.生成器模型
對于生成器,我們將噪聲輸入提供給生成器,此噪聲輸入也稱為潛在變量。
潛在變量是潛在空間中的隱藏變量或未觀察到的變量,潛在空間是這些變量的多維空間。
直到我們的生成器受到訓(xùn)練并賦予這些點(diǎn)意義,該潛在空間才有意義,這些點(diǎn)被映射到判別器的輸入。我們將定義一個3維的潛在空間(可以更改維數(shù)),并實(shí)驗(yàn)生成器的行為和準(zhǔn)確度如何變化。我們將對潛在空間中的每個變量使用高斯分布。生成器使用一個隱藏層,該隱藏層將由15個具有ReLu激活函數(shù)的神經(jīng)元組成。輸出層將由兩個神經(jīng)元組成,這兩個神經(jīng)元將連接到判別器層的輸入。
4.訓(xùn)練GAN模型
訓(xùn)練GAN模型的方法有很多,最簡單的方法是創(chuàng)建一個新的模型,該模型由生成器和判別器兩部分組成。我們只是在邏輯上封裝了生成器和判別器網(wǎng)絡(luò)。我們將把GAN模型作為一個整體進(jìn)行訓(xùn)練,這樣來自判別器的反向傳播誤差也會更新生成器的權(quán)重。如果判別器能夠很好地進(jìn)行分類,那么生成器的權(quán)重將更新得更多;如果判別器不能很好地進(jìn)行分類,那么生成器的權(quán)重將更新得少一些。這樣,在生成器和判別器之間就形成了一種對抗關(guān)系。Python實(shí)現(xiàn)代碼如下:
判別器模型的可訓(xùn)練屬性被設(shè)置為false,這樣就可以僅對standalone模型進(jìn)行訓(xùn)練。
現(xiàn)在我們只剩下對GAN模型進(jìn)行整體訓(xùn)練了。我們將編寫一個函數(shù)來做這個的事情。該函數(shù)將運(yùn)行10000個epochs,每運(yùn)行2000個epochs,它將評估判別器和生成器的性能。Python實(shí)現(xiàn)的代碼如下:
5.評估性能
在每隔一定的epochs之后,我們將調(diào)用show_performance函數(shù),該函數(shù)將從生成器中獲取真實(shí)樣本和虛假樣本并預(yù)測結(jié)果。我們還將在散點(diǎn)圖上繪制結(jié)果,以便我們可以查看GAN的性能。Python實(shí)現(xiàn)的代碼如下:
在epoch = 2000之后,我們得到了散點(diǎn)圖如下,您的圖可能會有所不同。
紅點(diǎn)表示real點(diǎn),藍(lán)點(diǎn)表示生成器生成的點(diǎn)。我們可以看到,藍(lán)點(diǎn)已開始呈y =x的形狀。
如果我們繼續(xù)進(jìn)行10000個epochs,您將得到類似下面的圖像。您可以嘗試使用更多個epochs(例如15000或20000個epochs)來獲得更好的準(zhǔn)確性。
現(xiàn)在我們可以看到,我們已經(jīng)從生成器中得到了一個更確定的樣本,我們可以說生成器已經(jīng)學(xué)習(xí)并擬合了這個函數(shù)。也就是說,僅僅通過誤差梯度,生成器就學(xué)會了這個函數(shù)。
總結(jié)
以上是生活随笔為你收集整理的python开源库生成式对抗网络_Python:使用Tensorflow开发一维生成对抗网络的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 笔记:Zygote和SystemServ
- 下一篇: 02 检索数据