使用Jittor实现Conditional GAN
Jittor實現Conditional GAN
Generative Adversarial Nets(GAN)提出了一種新的方法來訓練生成模型。然而,GAN對于要生成的圖片缺少控制。Conditional GAN(CGAN)通過添加顯式的條件或標簽,來控制生成的圖像。本文講解了CGAN的網絡結構、損失函數設計、使用CGAN生成一串數字、從頭訓練CGAN、以及在mnist手寫數字數據集上的訓練結果。
CGAN網絡架構
通過在生成器generator和判別器discriminator中添加相同的額外信息y,GAN就可以擴展為一個conditional模型。y可以是任何形式的輔助信息,例如類別標簽或者其他形式的數據。可以通過將y作為額外輸入層,添加到生成器和判別器來完成條件控制。
在生成器generator中,除了y之外,還額外輸入隨機一維噪聲z,為結果生成提供更多靈活性。
損失函數
GAN的損失函數
在解釋CGAN的損失函數之前,首先介紹GAN的損失函數。下面是GAN的損失函數設計。
對于判別器D,要訓練最大化這個loss。如果D的輸入是來自真實樣本的數據x,則D的輸出D(x)要盡可能地大,log(D(x))也會盡可能大。如果D的輸入是來自G生成的假圖片G(z),則D的輸出D(G(z))應盡可能地小,從而log(1-D(G(z))會盡可能地大。這樣可以達到max D的目的。
對于生成器G,要訓練最小化這個loss。對于G生成的假圖片G(z),希望盡可能地騙過D,讓它覺得生成的圖片就是真的圖片,這樣就達到了G“以假亂真”的目的。那么D的輸出D(G(z))應盡可能地大,從而log(1-D(G(z))會盡可能地小。這樣可以達到min G的目的。
D和G以這樣的方式聯合訓練,最終達到G的生成能力越來越強,D的判別能力越來越強的目的。
CGAN的損失函數
下面是CGAN的損失函數設計。
很明顯,CGAN的loss跟GAN的loss的區別就是多了條件限定y。D(x/y)代表在條件y下,x為真的概率。D(G(z/y))表示在條件y下,G生成的圖片被D判別為真的概率。
Jittor代碼數字生成
首先,導入需要的包,并且設置好所需的超參數:
import jittor as jt
from jittor import nn
import numpy as np
import pylab as pl
%matplotlib inline
隱空間向量長度
latent_dim = 100
類別數量
n_classes = 10
圖片大小
img_size = 32
圖片通道數量
channels = 1
圖片張量的形狀
img_shape = (channels, img_size, img_size)
第一步,定義生成器G。該生成器輸入兩個一維向量y和noise,生成一張圖片。
class Generator(nn.Module):
def init(self):
super(Generator, self).init()
self.label_emb = nn.Embedding(n_classes, n_classes)
def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2))return layersself.model = nn.Sequential(*block((latent_dim + n_classes), 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())def execute(self, noise, labels):gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)img = self.model(gen_input)img = img.view((img.shape[0], *img_shape))return img
第二步,定義判別器D。D輸入一張圖片和對應的y,輸出是真圖片的概率。
class Discriminator(nn.Module):
def init(self):
super(Discriminator, self).init()
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear((n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 1))
def execute(self, img, labels):d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)validity = self.model(d_in)return validity
第三步,使用CGAN生成一串數字。
代碼如下。可以使用訓練好的模型來生成圖片,也可以使用提供的預訓練參數: 模型預訓練參數下載:https://cloud.tsinghua.edu.cn/d/fbe30ae0967942f6991c/。
下載提供的預訓練參數
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
生成自定義的數字:
定義模型
generator = Generator()
discriminator = Discriminator()
generator.eval()
discriminator.eval()
加載參數
generator.load(’./generator_last.pkl’)
discriminator.load(’./discriminator_last.pkl’)
定義一串數字
number = “201962517”
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)
pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
生成結果如下,測試的完整代碼在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/test.py。
從頭訓練Condition GAN
從頭訓練 Condition GAN 的完整代碼在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/cgan.py,下載下來看看!
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
!python3.7 ./cgan.py --help
選擇合適的batch size,運行試試
運行命令: !python3.7 ./cgan.py --batch_size 8
下載下來的代碼里面定義損失函數、數據集、優化器。損失函數采用MSELoss、數據集采用MNIST、優化器采用Adam 如下(此段代碼僅僅用于解釋意圖,不能運行,需要運行請運行完整文件cgan.py):
此段代碼僅僅用于解釋意圖,不能運行,需要運行請運行完整文件cgan.py
Define Loss
adversarial_loss = nn.MSELoss()
Define Model
generator = Generator()
discriminator = Discriminator()
Define Dataloader
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(opt.img_size),
transform.Gray(),
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
模型訓練的代碼如下(此段代碼僅僅用于解釋意圖,不能運行,需要運行請運行完整文件cgan.py):
此段代碼僅僅用于解釋意圖,不能運行,需要運行請運行完整文件cgan.py
valid表示真,fake表示假
valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()
真實圖像和對應的標簽
real_imgs = jt.array(imgs)
labels = jt.array(labels)
#########################################################
訓練生成器G
- 希望生成的圖片盡可能地讓D覺得是valid
#########################################################
隨機向量z和隨機生成的標簽
z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()
隨機向量z和隨機生成的標簽經過生成器G生成的圖片,希望判別器能夠認為生成的圖片和生成的標簽是一致的,以此優化生成器G的生成能力。
gen_imgs = generator(z, gen_labels)
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)
#########################################################
訓練判別器D
- 盡可能識別real_imgs為valid
- 盡可能識別gen_imgs為fake
#########################################################
真實的圖片和標簽經過判別器的結果,要盡可能接近valid。
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
G生成的圖片和對應的標簽經過判別器的結果,要盡可能接近fake。
validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
MNIST數據集訓練結果
下面展示了Jittor版CGAN在MNIST數據集的訓練結果。下面分別是訓練0 epoch和90 epoches的結果。
總結
以上是生活随笔為你收集整理的使用Jittor实现Conditional GAN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 计图(Jittor) 1.1版本:新增骨
- 下一篇: XLearning - 深度学习调度平台