生活随笔
收集整理的這篇文章主要介紹了
WGAN-GP 学习笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
今天看到paperweekly上有人分享了一個WGAN-GP的實現,是以MNIST為數據集,代碼簡潔,結構清晰。我最近也在看GAN的相關內容,就下載下來做個參考。?
代碼地址:https://github.com/bojone/gan/
對于這個基于tensorflow實現的代碼,我對其進行了簡單的注釋
?
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
from scipy import misc,ndimage#讀入本地的MNIST數據集,該函數為mnist專用
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)batch_size = 100 #每個batch的大小
width,height = 28,28 #每張圖片包含28*28個像素點
mnist_dim = width*height #用一個數字數組表示一張圖,那么這個數組展開成向量的長度就是28*28=784
random_dim = 10 #每張圖表示一個數字,從0到9
epochs = 1000000 #共100萬輪def my_init(size): #從[-0.05,0.05]的均勻分布中采樣得到維度是size的輸出return tf.random_uniform(size, -0.05, 0.05)#判別器相關參數設定
D_W1 = tf.Variable(my_init([mnist_dim, 128])) #784*128
D_b1 = tf.Variable(tf.zeros([128])) #長度為128的一維張量,值均為0
D_W2 = tf.Variable(my_init([128, 32]))
D_b2 = tf.Variable(tf.zeros([32]))
D_W3 = tf.Variable(my_init([32, 1]))
D_b3 = tf.Variable(tf.zeros([1]))
D_variables = [D_W1, D_b1, D_W2, D_b2, D_W3, D_b3]#生成器相關參數設定
G_W1 = tf.Variable(my_init([random_dim, 32]))
G_b1 = tf.Variable(tf.zeros([32]))
G_W2 = tf.Variable(my_init([32, 128]))
G_b2 = tf.Variable(tf.zeros([128]))
G_W3 = tf.Variable(my_init([128, mnist_dim]))
G_b3 = tf.Variable(tf.zeros([mnist_dim]))
G_variables = [G_W1, G_b1, G_W2, G_b2, G_W3, G_b3]#判別器網絡結構
def D(X):X = tf.nn.relu(tf.matmul(X, D_W1) + D_b1) #X的維度是100*784,D_W1維度是784*128,得到結果維度為100*128X = tf.nn.relu(tf.matmul(X, D_W2) + D_b2) #X的維度是100*128,D_W2維度是128*32,得到結果維度為100*32X = tf.matmul(X, D_W3) + D_b3 #X的維度是100*32,D_W3維度是32*1,得到結果維度為100*1return X#生成器網絡結構
def G(X):X = tf.nn.relu(tf.matmul(X, G_W1) + G_b1) #X的維度是100*10,G_W1維度是10*32,得到結果維度為100*32X = tf.nn.relu(tf.matmul(X, G_W2) + G_b2) #X的維度是100*32,G_W2維度是32*128,得到結果維度為100*128X = tf.nn.sigmoid(tf.matmul(X, G_W3) + G_b3) #X的維度是100*128,G_W3維度是128*784,得到結果維度為100*784return X#real_X是真實樣本,random_X是噪音數據,random_Y是生成器生成的偽樣本
real_X = tf.placeholder(tf.float32, shape=[batch_size, mnist_dim])
random_X = tf.placeholder(tf.float32, shape=[batch_size, random_dim])
random_Y = G(random_X)#求懲罰項,這個這個懲罰是“軟約束”,最終的結果不一定滿足這個約束,但是會在約束上下波動。這里Lipschitz約束的C=1
eps = tf.random_uniform([batch_size, 1], minval=0., maxval=1.) #eps是U[0,1]的隨機數
X_inter = eps*real_X + (1. - eps)*random_Y #在真實樣本和生成樣本之間隨機插值,希望這個約束可以“布滿”真實樣本和生成樣本之間的空間
grad = tf.gradients(D(X_inter), [X_inter])[0] #求梯度
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2, axis=1)) #求梯度的二范數
grad_pen = 10 * tf.reduce_mean(tf.nn.relu(grad_norm - 1.)) #Lipschitz限制是要求判別器的梯度不超過K,這個loss項是希望判別器的梯度離K(此處K設為1)越近越好#判別器和生成器的損失函數
D_loss = tf.reduce_mean(D(real_X)) - tf.reduce_mean(D(random_Y)) + grad_pen
G_loss = tf.reduce_mean(D(random_Y)) #越接近真實樣本越好#判別器和生成器的優化函數
D_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(D_loss, var_list=D_variables)
G_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(G_loss, var_list=G_variables)#創建對話,初始化所有變量
sess = tf.Session()
sess.run(tf.global_variables_initializer())#是否存在“out”文件夾,不存在的話新建一個,存放實驗結果
if not os.path.exists('out/'):os.makedirs('out/')for e in range(epochs):for i in range(5): #每輪計算5個batchreal_batch_X,_ = mnist.train.next_batch(batch_size) #隨機抓取訓練數據中的100個批處理數據點random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim)) #從均勻分布中采樣,輸出100*10個樣本_,D_loss_ = sess.run([D_solver,D_loss], feed_dict={real_X:real_batch_X, random_X:random_batch_X})random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim))_,G_loss_ = sess.run([G_solver,G_loss], feed_dict={random_X:random_batch_X})#每1000輪輸出一次當前結果if e % 1000 == 0:print 'epoch %s, D_loss: %s, G_loss: %s'%(e, D_loss_, G_loss_)n_rows = 6check_imgs = sess.run(random_Y, feed_dict={random_X:random_batch_X}).reshape((batch_size, width, height))[:n_rows*n_rows] #由生成器得到偽樣本,維度為100*784,reshape為100個28*28的矩陣,取6*6個矩陣構成一張圖imgs = np.ones((width*n_rows+5*n_rows+5, height*n_rows+5*n_rows+5)) #203*203的值為1的二維矩陣for i in range(n_rows*n_rows):imgs[5+5*(i%n_rows)+width*(i%n_rows):5+5*(i%n_rows)+width+width*(i%n_rows), 5+5*(i/n_rows)+height*(i/n_rows):5+5*(i/n_rows)+height+height*(i/n_rows)] = check_imgs[i]misc.imsave('out/%s.png'%(e/1000), imgs)
轉自?https://blog.csdn.net/qq_20943513/article/details/73129308
總結
以上是生活随笔為你收集整理的WGAN-GP 学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。