tensorflow学习(7. GAN实现MNIST分类)
https://blog.csdn.net/CoderPai/article/details/70598403?utm_source=blogxgwz0
里面有比較全面的GAN的鏈接
原始論文鏈接:http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
一篇不錯的理解GAN的文章:https://blog.csdn.net/qq_31531635/article/details/70670271
?
這篇GAN代碼的出處?https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/
?
簡單用了別人的代碼,實現了一下,加入了自己理解的部分:
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import os from tensorflow.examples.tutorials.mnist import input_datasess = tf.InteractiveSession()mb_size = 128 Z_dim = 100mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)def weight_var(shape, name):return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())def bias_var(shape, name):return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0))# discriminater net #普通的兩層卷積網絡,作為鑒別網絡 X = tf.placeholder(tf.float32, shape=[None, 784], name='X')D_W1 = weight_var([784, 128], 'D_W1') D_b1 = bias_var([128], 'D_b1')D_W2 = weight_var([128, 1], 'D_W2') D_b2 = bias_var([1], 'D_b2')theta_D = [D_W1, D_W2, D_b1, D_b2]# generator net # 兩層網絡,輸入為100維的噪聲,這里是[-1,1]的均勻噪聲,作為生成網絡 Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')G_W1 = weight_var([100, 128], 'G_W1') G_b1 = bias_var([128], 'G_B1')G_W2 = weight_var([128, 784], 'G_W2') G_b2 = bias_var([784], 'G_B2')theta_G = [G_W1, G_W2, G_b1, G_b2]#具體網絡的結構def generator(z):G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob) #使用sigmoid給出該位置的值return G_probdef discriminator(x):D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit) #使用sigmoid給出該位置的值return D_prob, D_logitG_sample = generator(Z) #X為實際的樣本數據,G_sample為生成的樣本數據 D_real, D_logit_real = discriminator(X) D_fake, D_logit_fake = discriminator(G_sample) ''' D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) G_loss = -tf.reduce_mean(tf.log(D_fake)) ''' #D為辨別器,G為生成器,這里G是有助于辨別器提高性能的,G的輸入是隨機噪聲,如果G沒有訓練,產出的應該是無關樣本 #雖然也會提高一點性能,但是肯定不好,這里是希望G可以將噪聲映射到合理的數字圖的空間上,這里希望產出應該很接近 #合理圖片,那么D很有可能被判別圖像為真實圖像,所以G_loss最小化對應為G_sample被D網絡認定成真實圖像 D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) D_loss = D_loss_real + D_loss_fake G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))#D和G還是比較獨立的兩部分,分別寫開,不過兩部分需要互相提高,所以后續的訓練應該是交替進行 D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D) G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G) #隨機數生成 def sample_Z(m, n):'''Uniform prior for G(Z)'''return np.random.uniform(-1., 1., size=[m, n])def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples): # [i,samples[i]] imax=16ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return figif not os.path.exists('out/'):os.makedirs('out/')sess.run(tf.global_variables_initializer())i = 0 for it in range(1000000):#每1000次輸出一次if it % 1000 == 0:samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)}) # 16*784fig = plot(samples)#圖像存儲plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')i += 1plt.close(fig)X_mb, _ = mnist.train.next_batch(mb_size)#D和G的交替訓練,進行性能的互相提高_, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})if it % 1000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'.format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print()運行的結果是每1000次生成的圖像,前200張做成了視頻,結果如下:
總結一下:
1.GAN還是一種解決問題的框架,通過生成網絡G產生更高相關度的圖像來提升判別網絡D的性能
2.本文方法僅僅使用了神經網絡,沒有使用CNN,使用CNN產生判別網絡才是更合適的,本文在產生的200張的動圖后半部分圖像變化不明顯,將Loss顯示出來也可以看出來結果提升不高,這是網絡結構較差的原因。
3.GAN不止是提高判別網絡D的性能,也可以通過GAN的生成網絡G產生平價數據(弱監督中,有標簽數據較少,無標簽數據較多的情形)。可以通過GAN的生成網絡產生更多的有標簽數據用于訓練,論文見SSGAN。博客鏈接:https://blog.csdn.net/shenxiaolu1984/article/details/75736407
?
之后要看的一些東西:
1.Fine-tunning,復用別人的網絡并進行新的應用開發,學習網址鏈接:https://blog.csdn.net/u011600477/article/details/78607883
2.RCNN,博客見:https://blog.csdn.net/v1_vivian/article/details/78599229?utm_source=blogxgwz0,https://blog.csdn.net/WoPawn/article/details/52133338。RCNN關鍵在于預搜索框的選取,論文是Selective Search for Object Recognition,博客鏈接:https://blog.csdn.net/surgewong/article/details/39316931
3.YOLO: You Only Look Once,和RCNN的功能一致,但是想法不同,博客:https://blog.csdn.net/shenxiaolu1984/article/details/78826995
4.可視化網絡結構特征
?
總結
以上是生活随笔為你收集整理的tensorflow学习(7. GAN实现MNIST分类)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow学习(6.Alexn
- 下一篇: 论文笔记 《Selective Sear