GANs in Practice with PyTorch
An Introduction to GANs
GANs (Generative Adversarial Networks) were designed by Ian Goodfellow around an idea much like the martial-arts trick of "one hand fighting the other": the two "hands" are the GAN's generator (Generator) and discriminator (Discriminator), trained against each other.
The generator's job is to produce an image of a specified format from random noise, while the discriminator's job is to judge whether an input image is real or fake. The figure below shows the original GAN architecture.
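Formally, as stated in the original GAN paper, the generator G and discriminator D play a two-player minimax game over the value function:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

D is trained to assign high probability to real samples x and low probability to generated samples G(z), while G is trained to fool D; the per-network losses defined below are the two halves of this objective (with the common non-saturating variant for G).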
(Figure: the original GAN architecture; image from the web)
So the key components to implement in a GAN are the generator and the discriminator. Below we implement GANs in two different ways: in the first, the generator and discriminator are simple fully connected networks; in the second, they are convolutional networks.
Simple Neural Network
Here we focus on implementing the generator and discriminator and on defining the model's losses and optimizers; the full code is listed at the end. First the discriminator. Its architecture is a simple three-layer structure: input layer, hidden layer, output layer. The input images come from MNIST, so they are 28×28, and the activation function is LeakyReLU.
```python
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = torch.nn.Sequential(
            torch.nn.Linear(28*28, 128),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(128, 1))

    def forward(self, input):
        output = self.discriminator(input)
        return output
```

Next the generator, which takes a random vector of a given size and produces a 28×28 image. The closer the generated images get to real ones, the better the generator.
```python
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(100, 128),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(128, 28*28),
            torch.nn.Tanh())

    def forward(self, input):
        output = self.generator(input)
        return output
```

The random noise fed to the generator is produced by a helper function.
```python
def rand_img(batchsize, output_size):
    # Sample uniform noise in [-1, 1); Variable is a no-op wrapper in PyTorch >= 0.4
    Z = np.random.uniform(-1., 1., size=(batchsize, output_size))
    Z = np.float32(Z)
    Z = torch.from_numpy(Z)
    Z = Variable(Z.cuda())
    return Z
```

Next comes the loss definition. There are only two principles to keep in mind: we want the discriminator to classify every real image as 1 and every generated image as 0, while the generator wants its images to be classified as 1 by the discriminator. This is the essence of GANs; the concrete implementation follows.
```python
model_discriminator = Discriminator().cuda()
model_generator = Generator().cuda()

# Discriminator loss: real images should be classified as 1, fakes as 0
Z = rand_img(batchsize=batchsize, output_size=100)
X_gen = model_generator(Z)
X_train = X_train.view(-1, 28*28)   # the simple networks take flattened images
logits_real = model_discriminator(X_train)
logits_fake = model_discriminator(X_gen)
d_loss = loss_f(logits_real, torch.ones_like(logits_real)) \
       + loss_f(logits_fake, torch.zeros_like(logits_fake))

# Generator loss: generated images should be classified as 1
Z = rand_img(batchsize=batchsize, output_size=100)
X_gen = model_generator(Z)
logits_fake = model_discriminator(X_gen)
g_loss = loss_f(logits_fake, torch.ones_like(logits_fake))
```

We train to reduce d_loss, improving the discriminator, while simultaneously training to reduce g_loss, improving the generator. These two seemingly contradictory objectives are exactly what lets the whole model achieve very good results.
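The tension between the two losses can be made concrete with a little arithmetic on the BCE-with-logits formula (the helper below is my own, not from the article's code): the fake-image term of d_loss pushes logits_fake down while g_loss pushes it up, and the combined pressure is minimized exactly when the discriminator outputs 0.5, i.e. at a logit of 0 — the classic GAN equilibrium where the discriminator can no longer tell real from fake.

```python
import math

def bce_with_logits(x, t):
    """Numerically stable binary cross-entropy on a raw logit x with target t
    (the same formula behind torch.nn.BCEWithLogitsLoss)."""
    return max(x, 0.0) - x * t + math.log(1.0 + math.exp(-abs(x)))

def tug_of_war(logit_fake):
    # d_loss's fake term wants this logit classified as 0,
    # g_loss wants the very same logit classified as 1.
    return bce_with_logits(logit_fake, 0.0) + bce_with_logits(logit_fake, 1.0)

# The combined pressure is smallest at logit 0 (sigmoid(0) = 0.5, "can't tell"),
# where it equals 2*log(2).
print(min(range(-5, 6), key=tug_of_war))              # 0
print(abs(tug_of_war(0.0) - 2*math.log(2)) < 1e-12)   # True
```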
Convolutional Neural Network
GANs implemented with convolutions are also known as DCGANs. The biggest difference in the convolutional implementation is the convolutional components added to the model structure, and the final results are noticeably better than the previous version.
The discriminator uses a very common convolutional network structure.
```python
class Discriminator_conv(torch.nn.Module):
    def __init__(self):
        super(Discriminator_conv, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64*4*4, 64*4*4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64*4*4, 1))

    def forward(self, input):
        output = self.conv(input)
        output = output.view(-1, 64*4*4)
        output = self.dense(output)
        return output
```

The generator uses transposed convolution ("deconvolution"): for input size i, kernel size k, stride s and padding p, the output size is o = (i − 1)·s − 2p + k.
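As a sanity check on those layer dimensions, the standard output-size formulas can be traced through both networks in plain Python (the helper names below are mine, not from the article):

```python
def conv_out(i, k, s=1, p=0):
    """Output size of a convolution or pooling layer."""
    return (i + 2*p - k) // s + 1

def tconv_out(i, k, s, p):
    """Output size of a transposed convolution: o = (i - 1)*s - 2p + k."""
    return (i - 1)*s - 2*p + k

# Discriminator path: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4,
# which is where the dense layer's 64*4*4 input size comes from.
size = 28
for k, s in [(5, 1), (2, 2), (5, 1), (2, 2)]:
    size = conv_out(size, k, s)
print(size)  # 4

# Generator path: 7 -> 14 -> 28 with kernel_size=4, stride=2, padding=1.
print(tconv_out(7, 4, 2, 1), tconv_out(14, 4, 2, 1))  # 14 28
```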
```python
class Generator_conv(torch.nn.Module):
    def __init__(self):
        super(Generator_conv, self).__init__()
        self.conv_dense = torch.nn.Sequential(
            torch.nn.Linear(100, 1024),
            torch.nn.LeakyReLU(),
            torch.nn.BatchNorm1d(num_features=1024),
            torch.nn.Linear(1024, 7*7*128),
            torch.nn.BatchNorm1d(num_features=7*7*128))
        self.transpose_conv = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            torch.nn.Tanh())

    def forward(self, input):
        output = self.conv_dense(input)
        output = output.view(-1, 128, 7, 7)
        output = self.transpose_conv(output)
        return output
```

Finally, here are the results after training for 1 epoch, 10 epochs and 20 epochs. As you can see, the generator can already produce images that look like MNIST digits.
(Generated samples after 1 epoch)
(Generated samples after 10 epochs)
(Generated samples after 20 epochs)
Summary
Finally, a few small tricks.
1. The original d_loss can be changed to the following form.
```python
d_loss = loss_f(logits_real, torch.ones_like(logits_real)*(1-smooth)) \
       + loss_f(logits_fake, torch.zeros_like(logits_fake))
```

Multiplying the real targets by (1 − smooth), where smooth can be set between 0.1 and 0.9, is one-sided label smoothing; it helps keep the discriminator from overfitting.
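The effect is easy to see with a hand-rolled version of the BCE-with-logits formula (the helper below is mine, not from the article): with a smoothed target of 0.9, a discriminator that grows very confident on real images keeps paying a loss penalty, whereas with a hard target of 1.0 its loss can be driven to zero.

```python
import math

def bce_with_logits(x, t):
    """Numerically stable binary cross-entropy on a raw logit x with target t
    (the same formula torch.nn.BCEWithLogitsLoss computes)."""
    return max(x, 0.0) - x * t + math.log(1.0 + math.exp(-abs(x)))

smooth = 0.1
for logit in [1.0, 5.0, 10.0]:
    hard = bce_with_logits(logit, 1.0)            # decays toward 0 as confidence grows
    smoothed = bce_with_logits(logit, 1.0 - smooth)  # retains a floor of ~smooth * logit
    print(f"logit={logit:5.1f}  hard={hard:.4f}  smoothed={smoothed:.4f}")
```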
2. Lowering the optimizers' initial learning rate can help reduce the generator's g_loss.
```python
optimizer_dis = torch.optim.Adam(model_discriminator.parameters(), lr=0.0001)
optimizer_gen = torch.optim.Adam(model_generator.parameters(), lr=0.0001)
```

3. Building deeper network structures can achieve better results, at the cost of more training time.
Resources
A very comprehensive collection of GAN variants
Full Code
```python
import torch
import torchvision
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import numpy as np

# Jupyter notebook magics
%matplotlib inline
%config InlineBackend.figure_format="retina"

epoch_n = 20
batchsize = 128
smooth = 0.1

train_transform = transforms.ToTensor()
train_data = datasets.MNIST(root="data", download=True, train=True,
                            transform=train_transform)
train_load = torch.utils.data.DataLoader(dataset=train_data, shuffle=True,
                                         batch_size=batchsize)

def plot_img(img):
    img = torchvision.utils.make_grid(img)
    img = img.numpy().transpose(1, 2, 0)
    plt.figure(figsize=(12, 9))
    plt.imshow(img)

class Discriminator_conv(torch.nn.Module):
    def __init__(self):
        super(Discriminator_conv, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=5, stride=1),
            torch.nn.LeakyReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64*4*4, 64*4*4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64*4*4, 1))

    def forward(self, input):
        output = self.conv(input)
        output = output.view(-1, 64*4*4)
        output = self.dense(output)
        return output

class Generator_conv(torch.nn.Module):
    def __init__(self):
        super(Generator_conv, self).__init__()
        self.conv_dense = torch.nn.Sequential(
            torch.nn.Linear(100, 1024),
            torch.nn.LeakyReLU(),
            torch.nn.BatchNorm1d(num_features=1024),
            torch.nn.Linear(1024, 7*7*128),
            torch.nn.BatchNorm1d(num_features=7*7*128))
        self.transpose_conv = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(num_features=64),
            torch.nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            torch.nn.Tanh())

    def forward(self, input):
        output = self.conv_dense(input)
        output = output.view(-1, 128, 7, 7)
        output = self.transpose_conv(output)
        return output

def initialize_weights(m):
    if isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight.data)

model_discriminator = Discriminator_conv().cuda()
model_discriminator.apply(initialize_weights)
model_generator = Generator_conv().cuda()
model_generator.apply(initialize_weights)

loss_f = torch.nn.BCEWithLogitsLoss()

optimizer_dis = torch.optim.Adam(model_discriminator.parameters(), lr=0.0001)
optimizer_gen = torch.optim.Adam(model_generator.parameters(), lr=0.0001)

samples = []
losses = []

def rand_img(batchsize, output_size):
    Z = np.random.uniform(-1., 1., size=(batchsize, output_size))
    Z = np.float32(Z)
    Z = torch.from_numpy(Z)
    Z = Variable(Z.cuda())
    return Z

for epoch in range(epoch_n):
    for batch in train_load:
        X_train, y_train = batch
        X_train, y_train = Variable(X_train.cuda()), Variable(y_train.cuda())
        # X_train, y_train = Variable(X_train), Variable(y_train)

        # Discriminator step
        Z = rand_img(batchsize=batchsize, output_size=100)
        optimizer_dis.zero_grad()
        X_gen = model_generator(Z)
        X_gen = X_gen.view(-1, 1, 28, 28)
        X_train = X_train.view(-1, 1, 28, 28)
        logits_real = model_discriminator(X_train)
        logits_fake = model_discriminator(X_gen)
        d_loss = loss_f(logits_real, torch.ones_like(logits_real)*(1-smooth)) \
               + loss_f(logits_fake, torch.zeros_like(logits_fake))
        d_loss.backward(retain_graph=True)
        optimizer_dis.step()

        # Generator step
        optimizer_gen.zero_grad()
        Z = rand_img(batchsize=batchsize, output_size=100)
        X_gen = model_generator(Z)
        X_gen = X_gen.view(-1, 1, 28, 28)
        logits_fake = model_discriminator(X_gen)
        g_loss = loss_f(logits_fake, torch.ones_like(logits_fake))
        g_loss.backward()
        optimizer_gen.step()

    print("Epoch{}/{}...".format(epoch+1, epoch_n),
          "Discriminator Loss:{:.4f}...".format(d_loss),
          "Generator Loss:{:.4f}...".format(g_loss))
    losses.append((d_loss, g_loss))
    fake_img = model_generator(Z)
    samples.append(fake_img)

fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator')
plt.plot(losses.T[1], label='Generator')
plt.title("Training Losses")
plt.legend()

def to_img(img):
    img = img.detach().cpu().data
    img = img.clamp(0, 1)
    img = img.view(-1, 1, 28, 28)
    return img

for i in range(len(samples)):
    img = to_img(samples[i])
    plot_img(img)
```

Original article: https://zhuanlan.zhihu.com/p/40393929