WGAN-GP Code Annotations
Original post: https://blog.csdn.net/qq_20943513/article/details/73129308
Code: https://github.com/bojone/gan/
I have added brief annotations to this TensorFlow 1.x implementation.
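Before the code, it may help to write out the objective the script implements. The following is my own summary in the code's sign convention (flipped relative to the usual WGAN write-up, but internally consistent: the critic learns to score real samples low and fakes high, and the generator chases the low scores), with penalty weight λ = 10 and a one-sided penalty:

```latex
% Critic (discriminator) objective, minimized over D:
\min_D \;\; \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[D(x)\right]
          - \mathbb{E}_{z \sim p_z}\!\left[D(G(z))\right]
          + \lambda \, \mathbb{E}_{\hat{x}}\!\left[\max\!\left(0,\;
            \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)\right]

% Generator objective, minimized over G, and the interpolated points
% where the gradient penalty is evaluated:
\min_G \;\; \mathbb{E}_{z \sim p_z}\!\left[D(G(z))\right],
\qquad \hat{x} = \epsilon\, x + (1-\epsilon)\, G(z),
\quad \epsilon \sim U[0,1]
```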
```python
import os

import numpy as np
import tensorflow as tf
from scipy import misc
from tensorflow.examples.tutorials.mnist import input_data

# Load the local MNIST dataset (this reader is MNIST-specific).
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

batch_size = 100            # size of each batch
width, height = 28, 28      # each image is 28*28 pixels
mnist_dim = width * height  # an image flattened into a vector has length 28*28 = 784
random_dim = 10             # dimension of the noise vector fed to the generator
epochs = 1000000            # 1,000,000 training iterations in total

def my_init(size):
    # Sample a tensor of shape `size` from the uniform distribution U[-0.05, 0.05].
    return tf.random_uniform(size, -0.05, 0.05)

# Discriminator parameters
D_W1 = tf.Variable(my_init([mnist_dim, 128]))  # 784*128
D_b1 = tf.Variable(tf.zeros([128]))            # 1-D tensor of length 128, all zeros
D_W2 = tf.Variable(my_init([128, 32]))
D_b2 = tf.Variable(tf.zeros([32]))
D_W3 = tf.Variable(my_init([32, 1]))
D_b3 = tf.Variable(tf.zeros([1]))
D_variables = [D_W1, D_b1, D_W2, D_b2, D_W3, D_b3]

# Generator parameters
G_W1 = tf.Variable(my_init([random_dim, 32]))
G_b1 = tf.Variable(tf.zeros([32]))
G_W2 = tf.Variable(my_init([32, 128]))
G_b2 = tf.Variable(tf.zeros([128]))
G_W3 = tf.Variable(my_init([128, mnist_dim]))
G_b3 = tf.Variable(tf.zeros([mnist_dim]))
G_variables = [G_W1, G_b1, G_W2, G_b2, G_W3, G_b3]

# Discriminator network
def D(X):
    X = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)  # X is 100*784, D_W1 is 784*128 -> 100*128
    X = tf.nn.relu(tf.matmul(X, D_W2) + D_b2)  # X is 100*128, D_W2 is 128*32  -> 100*32
    X = tf.matmul(X, D_W3) + D_b3              # X is 100*32,  D_W3 is 32*1    -> 100*1
    return X

# Generator network
def G(X):
    X = tf.nn.relu(tf.matmul(X, G_W1) + G_b1)     # X is 100*10,  G_W1 is 10*32   -> 100*32
    X = tf.nn.relu(tf.matmul(X, G_W2) + G_b2)     # X is 100*32,  G_W2 is 32*128  -> 100*128
    X = tf.nn.sigmoid(tf.matmul(X, G_W3) + G_b3)  # X is 100*128, G_W3 is 128*784 -> 100*784
    return X

# real_X holds real samples, random_X holds noise, random_Y is the generator's fake samples.
real_X = tf.placeholder(tf.float32, shape=[batch_size, mnist_dim])
random_X = tf.placeholder(tf.float32, shape=[batch_size, random_dim])
random_Y = G(random_X)

# Gradient penalty term. This is a "soft" constraint: the trained discriminator is not
# guaranteed to satisfy it exactly, but will fluctuate around it. The Lipschitz
# constant here is C = 1.
eps = tf.random_uniform([batch_size, 1], minval=0., maxval=1.)  # eps ~ U[0, 1]
X_inter = eps * real_X + (1. - eps) * random_Y  # random interpolation between real and fake samples, so the constraint "fills" the space between them (eps has shape [100, 1] and broadcasts across the 784 pixels)
grad = tf.gradients(D(X_inter), [X_inter])[0]            # gradient of D at the interpolated points
grad_norm = tf.sqrt(tf.reduce_sum((grad) ** 2, axis=1))  # L2 norm of that gradient
grad_pen = 10 * tf.reduce_mean(tf.nn.relu(grad_norm - 1.))  # the Lipschitz constraint asks that the discriminator's gradient norm not exceed K (here K = 1); this term penalizes norms above K

# Discriminator and generator losses
D_loss = tf.reduce_mean(D(real_X)) - tf.reduce_mean(D(random_Y)) + grad_pen
G_loss = tf.reduce_mean(D(random_Y))  # the closer the fake samples score to the real ones, the better

# Optimizers for the discriminator and the generator
D_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(D_loss, var_list=D_variables)
G_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(G_loss, var_list=G_variables)

# Create a session and initialize all variables.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create an 'out' directory for results if it does not already exist.
if not os.path.exists('out/'):
    os.makedirs('out/')

for e in range(epochs):
    for i in range(5):  # 5 discriminator updates per generator update
        real_batch_X, _ = mnist.train.next_batch(batch_size)  # grab a random batch of 100 training images
        random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim))  # sample 100*10 noise values from U[-1, 1]
        _, D_loss_ = sess.run([D_solver, D_loss],
                              feed_dict={real_X: real_batch_X, random_X: random_batch_X})
    random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim))
    _, G_loss_ = sess.run([G_solver, G_loss], feed_dict={random_X: random_batch_X})
    # Report progress and dump a sample grid every 1000 iterations.
    if e % 1000 == 0:
        print('epoch %s, D_loss: %s, G_loss: %s' % (e, D_loss_, G_loss_))
        n_rows = 6
        # Fake samples are 100*784; reshape to 100 images of 28*28 and tile the
        # first 6*6 of them into a single picture.
        check_imgs = sess.run(random_Y, feed_dict={random_X: random_batch_X}) \
                         .reshape((batch_size, width, height))[:n_rows * n_rows]
        imgs = np.ones((width * n_rows + 5 * n_rows + 5,
                        height * n_rows + 5 * n_rows + 5))  # 203*203 matrix of ones: a white background with 5-pixel margins
        for i in range(n_rows * n_rows):
            row = i % n_rows
            col = i // n_rows  # integer division, so this also runs under Python 3
            imgs[5 + (5 + width) * row : 5 + (5 + width) * row + width,
                 5 + (5 + height) * col : 5 + (5 + height) * col + height] = check_imgs[i]
        misc.imsave('out/%s.png' % (e // 1000), imgs)
```
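One detail worth flagging: the penalty above is one-sided (via tf.nn.relu), so only gradient norms above 1 are penalized. The original WGAN-GP paper (Gulrajani et al., 2017) uses a two-sided penalty that pulls the norm toward 1 from both directions. If you want to try the paper's version, it is a one-line change against the same graph (a sketch of mine, not part of the original code):

```python
# Two-sided gradient penalty from the WGAN-GP paper: penalize any
# deviation of the gradient norm from 1, not just the excess above 1.
grad_pen = 10 * tf.reduce_mean((grad_norm - 1.) ** 2)
```

The one-sided variant here is a deliberate relaxation: a critic whose gradient norm stays below 1 already satisfies the Lipschitz constraint, so there is no need to push the norm back up to 1.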
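Also note that scipy.misc.imsave was deprecated and later removed from SciPy, so the last line fails on recent SciPy versions. One possible drop-in replacement uses the imageio package (an extra dependency and my own suggestion, not part of the original script):

```python
import imageio  # pip install imageio

# imageio.imwrite expects integer pixel data; `imgs` holds floats in [0, 1]
# (sigmoid outputs on a white background), so rescale to uint8 before writing.
imageio.imwrite('out/%s.png' % (e // 1000), (imgs * 255).astype(np.uint8))
```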