【Gans入门】Pytorch实现Gans代码详解【70+代码】
簡述
由于科技論文老師要求閱讀Gans論文并在網上找到類似的代碼來學習。
文章目錄
- 簡述
- 代碼來源
- 代碼含義概覽
- 代碼分段解釋
- 導入包:
- 設置參數:
- 給出標準數據:
- 構建模型:
- 構建優化器
- 迭代細節
- 畫圖
- 全部代碼:
- 參考并學習的鏈接
代碼來源
https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/406_GAN.py
代碼含義概覽
這個大致講講這個代碼實現了什么。
這個模型的輸入為:一些數據夾雜在x2x^2x2和2x2+12x^2+12x2+1這個兩個函數之間的一些數據。這個用線性函數的隨機生成來生成這個東西
輸出: 這是一個生成模型,生成模型的結果就是生成通過上面的輸入數據輸出這樣的數據來畫一條曲線
-
我們每次只取15個在x方向上等距的點。然后畫出這條曲線來。
經過學習之后,我們要求這個模型能自己畫出一條在其中的曲線來。 -
當然,由于我們設置的區間是有弧線的,即區間的概率上是有偏差的。經過足夠多的擬合,有較高的概率使得整個模型畫出來的曲線也是一個弧線。
代碼分段解釋
導入包:
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt設置參數:
- LR_G:生成器的學習率
- LR_D:判別器的學習率
- N_IDEAS:生成器的啟發因子(就是生成器這個神經網絡的初始輸入層的節點數)
- ART_COMPONENTS:觀測節點–每次用于畫線的那些輸出點的數量
- BATCH_SIZE:其實是輸入數據的數量。
- PAINT_POINTS :就是把重復的那么多數據(將x區間等分為觀測節點數量等分的x節點)疊起來而已。這樣之后就直接代入就可以知道數據了。
給出標準數據:
這個函數,會給出特定規模的標準數據
- 先創建一個(BATCH_SIZE,1)規模的來自于(1,2)均勻分布的隨機數。
- 再用這個數據構建 a?x2+(a?1)a*x ^2 + (a - 1)a?x2+(a?1) 其中a來自于(1,2)(1,2)(1,2)的均勻分布。然后有BATCH_SIZE 個結果,所以,我們會在前面說到,這個參數表示輸入集合的大小
構建模型:
搭建神經網絡
- 這里搭建的神經網絡,只需要構建映射層就好了。
- 生成器模型:先通過一個線性函數構建一個從N_IDEAS到128的映射。再通過激活函數ReLU()函數來做一個映射。最后,再用一個線性函數搭建從128到觀測點的映射。(這些映射都是用矩陣乘法來實現的,所以,其實參數空間是三個不同的矩陣)
- 判別式模型:先通過一個觀測點的到128的模型。再通過一個ReLU激活函數。之后,再用一個線性函數使得從128到1維度。一維就是常數,再做一個sigmoid的激活函數映射到(0,1)(0,1)(0,1)空間。表示概率。
構建優化器
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D) opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)構建了兩個優化器。其實就是把對應模型的參數放進來了而已,之后,再設置一下學習率。
這里采用的是Adam模型來做優化。
迭代細節
其實這上面應該還有一些畫圖而加上的函數,但是對于模型不是很重要,這里就不看了。最后會有一個整體的模型。
for step in range(10000):明顯看出,使用了10000次的迭代。
- 先調用標準數據生成函數,生成標準數據。
- 再用pytorch的隨機數來生特定大小的生成器啟發因子。
- 之后,再把這個隨機數丟給生成器。
- 明顯,通過這樣的訓練,其實逐漸的訓練這個生成器模型,在隨機給輸入的情況下,漸漸掌握輸出正確的結果(個人感覺這里有提高的可能)
再把假畫和真畫都丟給判別式模型。給出一個概率來。
之后構建兩個模型的交叉熵,需要降低的損失函數
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1)) G_loss = torch.mean(torch.log(1. - prob_artist1))這個其實是根據論文中的公式給出的。
- 注意到,這里跟下面算法中給出的梯度是相同的。就是前面少了個系數,但是有沒系數,對于這個不影響的。
其實上面只是把整個模型搭建起來,其實都還沒有運行的。
真正運行的部分是下面這里
注意到,其實非常重復的。
- 第一步的zero_grad()函數:
原因:
In PyTorch, we need to set the gradients to zero before starting to do backpropragation because PyTorch accumulates the gradients on subsequent backward passes. This is convenient while training RNNs. So, the default action is to accumulate the gradients on every loss.backward() call.
在PyTorch中,我們需要設置這個梯度到0,在開始反向傳播的訓練之前,因為Pytorch會累積這個梯度在之后的反向傳播過程中。這是非常方便的當訓練RNNs的時候,所以默認就這么設置了。
Because of this, when you start your training loop, ideally you should zero out the gradients so that you do the parameter update correctly. Else the gradient would point in some other directions than the intended direction towards the minimum (or maximum, in case of maximization objectives).
由于這個,當你開始你的訓練循環的時候,比較聰明的一點就是先把這個梯度設置為0,以確保你的訓練的參數會是正確的。否則的話,這個梯度會指向一些其他地方(亂跑)
上面的解釋來自于stackoverflow
https://stackoverflow.com/questions/48001598/why-is-zero-grad-needed-for-optimization
- 第二步:反向傳播,這里設置保留整個圖的情況下。
- 第三步:.step() 其實這個函數才真正表示這個模型被訓練了。
畫圖
由于我們每次生成時候后,其實都是生成了一個BATCH_SIZE個。但是我們一次畫太多的圖的話,會顯得很丑,所以就只畫第一個圖就好了。
這里取模的原因就在于避免畫太多的圖,導致耗費太多資源。
if step % 500 == 0: # plottingplt.cla()plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting', )# 2x^2 + 1plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')# x^2plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),fontdict={'size': 13})plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})plt.ylim((0, 3))plt.legend(loc='upper right', fontsize=10)plt.draw()plt.pause(0.01)全部代碼:
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt# Hyper Parameters BATCH_SIZE = 64 LR_G = 0.0001 # learning rate for generator LR_D = 0.0001 # learning rate for discriminator N_IDEAS = 5 # think of this as number of ideas for generating an art work (Generator) ART_COMPONENTS = 15 # it could be total point G can draw in the canvas PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])def artist_works(): # painting from the famous artist (real target)a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)paintings = torch.from_numpy(paintings).float()return paintingsG = nn.Sequential( # Generatornn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas )D = nn.Sequential( # Discriminatornn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(), # tell the probability that the art work is made by artist )opt_D = torch.optim.Adam(D.parameters(), lr=LR_D) opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)plt.ion() # something about continuous plottingfor step in range(10000):artist_paintings = artist_works() # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideasG_paintings = G(G_ideas) # fake painting from G (random ideas)prob_artist0 = D(artist_paintings) # D try to increase this probprob_artist1 = D(G_paintings) # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True) # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward(retain_graph=True)opt_G.step()if step % 500 == 0: # plottingplt.cla()plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting', )# 2x^2 + 1plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')# x^2plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),fontdict={'size': 13})plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})plt.ylim((0, 3))plt.legend(loc='upper right', fontsize=10)plt.draw()plt.pause(0.01)plt.ioff() plt.show()參考并學習的鏈接
- https://stackoverflow.com/questions/48001598/why-is-zero-grad-needed-for-optimization
- https://blog.csdn.net/cherrylvlei/article/details/53149381
- https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-optim/
- https://morvanzhou.github.io/tutorials/machine-learning/torch/4-06-GAN/
總結
以上是生活随笔為你收集整理的【Gans入门】Pytorch实现Gans代码详解【70+代码】的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: VS2017调用MySQL 8.0(附上
- 下一篇: 【论文阅读】Triple GANs论文阅