Deep Learning: Handwritten Digit Generation with DCGAN
This article covers much the same ground as the previous one, but with a different network architecture: the previous article used a plain GAN, while this one uses a DCGAN. The two differ in the following ways:
(1) Pooling layers are replaced with convolutions and transposed convolutions.
(2) Batch normalization is added in both the generator and the discriminator.
(3) Fully connected layers are removed and replaced with global pooling.
(4) The generator uses Tanh in its output layer and ReLU in the other layers.
(5) The discriminator uses LeakyReLU in all layers.
The most essential of these changes is the use of convolutions and transposed convolutions, as the shape check below illustrates.
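As a minimal sketch (not from the original article, with illustrative layer parameters): a strided convolution halves the spatial dimensions where a pooling layer would, and a transposed convolution doubles them back.

import tensorflow as tf

x = tf.random.normal([1, 28, 28, 1])  # dummy MNIST-sized input
# strided convolution downsamples 28x28 -> 14x14 without any pooling layer
down = tf.keras.layers.Conv2D(8, (5, 5), strides=(2, 2), padding='same')(x)
print(down.shape)  # (1, 14, 14, 8)
# transposed convolution upsamples 14x14 -> 28x28
up = tf.keras.layers.Conv2DTranspose(8, (5, 5), strides=(2, 2), padding='same')(down)
print(up.shape)  # (1, 28, 28, 8)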
1. Import libraries
import numpy as np
import glob, imageio, os, PIL
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
2. Data preparation

Normalize, shuffle, and batch the data:
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # normalize to [-1, 1]
batch_size = 256
buffer_size = 60000
datasets = tf.data.Dataset.from_tensor_slices(train_images).shuffle(buffer_size).batch(batch_size)
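A quick check of the pipeline (illustrative, not part of the original code) confirms each batch has the expected shape and value range:

for batch in datasets.take(1):
    print(batch.shape)  # (256, 28, 28, 1)
    print(float(tf.reduce_min(batch)), float(tf.reduce_max(batch)))  # close to -1.0 and 1.0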
3. Building the generator and discriminator

The generator uses tf.keras.layers.Conv2DTranspose, i.e. transposed convolution (often called deconvolution), whose purpose is to expand a feature map that has been shrunk by downsampling back to a larger size, for example from 3 × 3 up to 5 × 5.
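The original post omits the generator's code and shows only its summary. The sketch below is a reconstruction (modeled on the official TensorFlow DCGAN tutorial) whose layer shapes and parameter counts match the summary that follows; note the summary shows LeakyReLU activations, so the sketch follows it rather than the plain ReLU mentioned in point (4) above.

def Generator_model():
    model = tf.keras.Sequential([
        # project the 100-dim noise vector to a 7x7x256 feature map (no bias, per the 1,254,400 param count)
        tf.keras.layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Reshape((7, 7, 256)),
        tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),  # 7x7 -> 14x14
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        # output layer: 14x14 -> 28x28, Tanh to match the [-1, 1] pixel range
        tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model
generator = Generator_model()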
The network structure is as follows:
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 12544) 1254400 _________________________________________________________________ batch_normalization (BatchNo (None, 12544) 50176 _________________________________________________________________ leaky_re_lu (LeakyReLU) (None, 12544) 0 _________________________________________________________________ reshape (Reshape) (None, 7, 7, 256) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 7, 7, 128) 819200 _________________________________________________________________ batch_normalization_1 (Batch (None, 7, 7, 128) 512 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 7, 7, 128) 0 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64) 204800 _________________________________________________________________ batch_normalization_2 (Batch (None, 14, 14, 64) 256 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1) 1600 ================================================================= Total params: 2,330,944 Trainable params: 2,305,472 Non-trainable params: 25,472 _________________________________________________________________判別器的構建:
def Discriminator_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding="same", input_shape=(28, 28, 1)),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Flatten(),
        # output a raw logit; the loss below uses from_logits=True, so no sigmoid here
        tf.keras.layers.Dense(1)
    ])
    return model
discriminator = Discriminator_model()

The network structure is as follows:
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 14, 14, 64) 1664 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 14, 14, 64) 0 _________________________________________________________________ dropout (Dropout) (None, 14, 14, 64) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 7, 7, 128) 204928 _________________________________________________________________ leaky_re_lu_4 (LeakyReLU) (None, 7, 7, 128) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 7, 7, 128) 0 _________________________________________________________________ flatten (Flatten) (None, 6272) 0 _________________________________________________________________ dense_1 (Dense) (None, 1) 6273 ================================================================= Total params: 212,865 Trainable params: 212,865 Non-trainable params: 0 _________________________________________________________________這一部分是與上篇文章代碼的本質差異。
4. Remaining steps
Due to hardware limitations, epochs is set to 60. If anything here is unclear, please refer to the previous article.
# loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# discriminator loss: real images should be classified as 1, generated images as 0
def Discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss

# generator loss: the generator wants its fakes classified as 1
def Generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out), fake_out)

# optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

epochs = 60
noise_dim = 100
num_exp_to_generate = 16

seed = tf.random.normal([num_exp_to_generate, noise_dim])

def train_step(images):
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        gen_images = generator(noise, training=True)
        real_out = discriminator(images, training=True)
        fake_out = discriminator(gen_images, training=True)
        gen_loss = Generator_loss(fake_out)
        dis_loss = Discriminator_loss(real_out, fake_out)
    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    dis_gradient = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(dis_gradient, discriminator.trainable_variables))

def Generator_plot_image(gen_model, test_noise, epoch):
    pre_images = gen_model(test_noise, training=False)  # generate images from test_noise with the generator frozen
    fig = plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow((pre_images[i, :, :, 0] + 1) / 2, cmap='gray')  # pixels were in [-1, 1]; shift and scale back to [0, 1]
        plt.axis('off')
    fig.savefig("E:/tmp/.keras/datasets/num_gen_DCGAN/%05d.png" % epoch)
    plt.close()

def train(dataset, epochs):
    for epoch in range(epochs):
        for img_batch in dataset:
            train_step(img_batch)
            print('.', end='')
        print()
        Generator_plot_image(generator, seed, epoch)

train(datasets, epochs)

The generated results are shown below (my server lease expired before the run finished; this is one of the better images).
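The imports include glob and imageio, yet the code above never uses them; presumably the original stitched the per-epoch snapshots into an animated GIF. A sketch under that assumption (the output filename is hypothetical):

import glob
import imageio

# collect the snapshots saved by Generator_plot_image, in epoch order, and write a GIF
with imageio.get_writer('dcgan_mnist.gif', mode='I') as writer:
    for path in sorted(glob.glob('E:/tmp/.keras/datasets/num_gen_DCGAN/*.png')):
        writer.append_data(imageio.imread(path))

One further optimization worth noting: the official TensorFlow DCGAN tutorial decorates train_step with @tf.function so each step runs as a compiled graph; the loop above works without it, just more slowly.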
Keep working hard!