54_pytorch GAN (Generative Adversarial Networks), GAN code example, WGAN code example
1.54.GAN (Generative Adversarial Networks)
1.54.1.What is a GAN
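In 2014, Ian Goodfellow and his colleagues at the Université de Montréal published a paper that shook the research community. Yes, I am talking about "Generative Adversarial Nets", which marked the birth of generative adversarial networks (GANs) through an innovative combination of computational graphs and game theory. Their work showed that, given enough modeling capacity, two competing models can be trained together using plain backpropagation.
The two models have clearly defined roles. Given a real dataset R, G is the generator, whose job is to produce fake data that can pass for real. D is the discriminator, which receives data either from the real dataset or from G and labels it as real or fake. Ian Goodfellow's analogy is that G is like a forgery workshop trying to make its products as close to the genuine article as possible so that they slip through, while D is the art appraiser who must tell genuine pieces from high-quality fakes (in this setting, though, the forger G never sees the original data, only D's verdicts, so it is working blind).
Ideally, both D and G keep improving as training goes on, until G has essentially become a "master forger" and D loses to G because it can no longer tell the two data distributions apart.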
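The contest between G and D described above is usually written as the minimax objective from the "Generative Adversarial Nets" paper. The symbols p_data (the distribution of the real dataset R) and p_z (the noise distribution fed to G) are notation taken from that paper, not from this article:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

D tries to make both terms large by scoring real data near 1 and generated data near 0, while G tries to make the second term small by getting D(G(z)) close to 1.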
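I. GAN (Generative Adversarial Nets)
There are many kinds of neural networks. Common ones include:
1. Ordinary feed-forward networks.
2. Convolutional neural networks, used for analyzing images.
3. Recurrent neural networks (RNNs), used for analyzing sequential information such as speech or text.
All three share one thing: they do their job by learning the association between input data and the corresponding results.
There is one more, rather special kind: the GAN (generative adversarial network), which can be thought of as a network for producing data.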
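The Generator turns random numbers into meaningful data. The Discriminator learns which data are real and which are generated, and its judgment is backpropagated to the Generator so that it produces more useful data. A generative adversarial network is therefore two networks, one that generates and one that opposes it, and the point of the contest is to make the generator do what we want.
From my own study, my understanding is that the G network takes random numbers as input and produces data from them, and whether that data is any good is decided by the D network. D learns from and summarizes the existing data, then guides G toward producing data similar to it, so D plays the role of a coach.
In the end, random input drawn from an arbitrary distribution can be turned into data that resembles the original data and can be used for other purposes. That is my plain understanding of GANs.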
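To make "D decides whether G's data is any good" concrete, here is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). It assumes the standard log-loss formulation from the "Generative Adversarial Nets" paper and the non-saturating generator loss suggested there. The function names and the sample probabilities are made up for illustration and are not taken from the code examples in this article:

```python
import math


def discriminator_loss(d_real, d_fake):
    """Batch mean of -[log D(x) + log(1 - D(G(z)))].

    d_real: D's outputs on real samples, each strictly between 0 and 1
    d_fake: D's outputs on generated samples, each strictly between 0 and 1
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return -(real_term + fake_term)


def generator_loss(d_fake):
    """Non-saturating generator loss: batch mean of -log D(G(z))."""
    return -sum(math.log(p) for p in d_fake) / len(d_fake)


# Hypothetical scores early in training: D separates real from fake easily,
# so D's loss is small and G's loss is large.
early_d = discriminator_loss([0.9, 0.95], [0.1, 0.05])
early_g = generator_loss([0.1, 0.05])

# Hypothetical scores at the ideal end state: D cannot tell the two apart
# and outputs 0.5 for everything.
balanced_d = discriminator_loss([0.5, 0.5], [0.5, 0.5])
balanced_g = generator_loss([0.5, 0.5])

print("early:   ", early_d, early_g)
print("balanced:", balanced_d, balanced_g)
```

When D outputs 0.5 everywhere, the discriminator loss equals 2·ln 2 and the generator loss equals ln 2, which is the "D can no longer tell the difference" situation described earlier.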
1.54.2.How to train
1.54.3.GAN code example
    # -*- coding: UTF-8 -*-
    import random

    import numpy as np
    import torch
    import visdom
    from matplotlib import pyplot as plt
    from torch import nn, optim, autograd

    h_dim = 400
    batchsz = 512
    viz = visdom.Visdom()  # needs a running visdom server


    class Generator(nn.Module):

        def __init__(self):
            super(Generator, self).__init__()

            # maps 2-D noise z to a 2-D sample
            self.net = nn.Sequential(
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, 2),
            )

        def forward(self, z):
            output = self.net(z)
            return output


    class Discriminator(nn.Module):

        def __init__(self):
            super(Discriminator, self).__init__()

            # maps a 2-D sample to a single score in (0, 1)
            self.net = nn.Sequential(
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            output = self.net(x)
            return output.view(-1)


    def data_generator():
        # endless stream of batches drawn from 8 Gaussians placed on a circle
        scale = 2.
        centers = [
            (1, 0),
            (-1, 0),
            (0, 1),
            (0, -1),
            (1. / np.sqrt(2), 1. / np.sqrt(2)),
            (1. / np.sqrt(2), -1. / np.sqrt(2)),
            (-1. / np.sqrt(2), 1. / np.sqrt(2)),
            (-1. / np.sqrt(2), -1. / np.sqrt(2))
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        while True:
            dataset = []
            for i in range(batchsz):
                point = np.random.randn(2) * .02
                center = random.choice(centers)
                point[0] += center[0]
                point[1] += center[1]
                dataset.append(point)
            dataset = np.array(dataset, dtype='float32')
            dataset /= 1.414  # stdev
            yield dataset

        # for i in range(100000//25):
        #     for x in range(-2, 3):
        #         for y in range(-2, 3):
        #             point = np.random.randn(2).astype(np.float32) * 0.05
        #             point[0] += 2 * x
        #             point[1] += 2 * y
        #             dataset.append(point)
        #
        # dataset = np.array(dataset)
        # print('dataset:', dataset.shape)
        # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
        #
        # while True:
        #     np.random.shuffle(dataset)
        #
        #     for i in range(len(dataset)//batchsz):
        #         yield dataset[i*batchsz : (i+1)*batchsz]


    def generate_image(D, G, xr, epoch):
        """
        Generates and saves a plot of the true distribution, the generator, and the
        critic.
        """
        N_POINTS = 128
        RANGE = 3
        plt.clf()

        points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
        points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
        points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
        points = points.reshape((-1, 2))
        # (16384, 2)
        # print('p:', points.shape)

        # draw contour
        with torch.no_grad():
            points = torch.Tensor(points).cuda()  # [16384, 2]
            disc_map = D(points).cpu().numpy()  # [16384]
        x = y = np.linspace(-RANGE, RANGE, N_POINTS)
        cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
        plt.clabel(cs, inline=1, fontsize=10)
        # plt.colorbar()

        # draw samples
        with torch.no_grad():
            z = torch.randn(batchsz, 2).cuda()  # [b, 2]
            samples = G(z).cpu().numpy()  # [b, 2]
        plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
        plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

        viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


    def weights_init(m):
        if isinstance(m, nn.Linear):
            # m.weight.data.normal_(0.0, 0.02)
            nn.init.kaiming_normal_(m.weight)
            m.bias.data.fill_(0)


    def gradient_penalty(D, xr, xf):
        """
        :param D: discriminator
        :param xr: real batch, [b, 2]
        :param xf: generated batch, [b, 2]
        :return: gradient penalty (scalar tensor)
        """
        LAMBDA = 0.3

        # only constrain the Discriminator
        xf = xf.detach()
        xr = xr.detach()

        # [b, 1] => [b, 2]
        alpha = torch.rand(batchsz, 1).cuda()
        alpha = alpha.expand_as(xr)

        # random points on the segments between real and generated samples
        interpolates = alpha * xr + ((1 - alpha) * xf)
        interpolates.requires_grad_()

        disc_interpolates = D(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        # penalize gradient norms that differ from 1
        gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

        return gp


    def main():
        torch.manual_seed(23)
        np.random.seed(23)

        G = Generator().cuda()
        D = Discriminator().cuda()
        G.apply(weights_init)
        D.apply(weights_init)

        optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
        optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

        data_iter = data_generator()
        print('batch:', next(data_iter).shape)

        viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',
                                                      legend=['D', 'G']))

        for epoch in range(50000):

            # 1. train discriminator for k steps
            for _ in range(5):
                x = next(data_iter)
                xr = torch.from_numpy(x).cuda()

                # [b]
                predr = (D(xr))
                # maximize D(xr) by minimizing -D(xr)
                lossr = - (predr.mean())

                # [b, 2]
                z = torch.randn(batchsz, 2).cuda()
                # stop gradient on G
                # [b, 2]
                xf = G(z).detach()
                # [b]
                predf = (D(xf))
                # min predf
                lossf = (predf.mean())

                # gradient penalty
                gp = gradient_penalty(D, xr, xf)

                loss_D = lossr + lossf + gp
                optim_D.zero_grad()
                loss_D.backward()
                # for p in D.parameters():
                #     print(p.grad.norm())
                optim_D.step()

            # 2. train Generator
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z)
            predf = (D(xf))
            # max predf
            loss_G = - (predf.mean())
            optim_G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if epoch % 100 == 0:
                viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

                # matplotlib cannot read CUDA tensors, so copy the real batch to host memory
                generate_image(D, G, xr.cpu().numpy(), epoch)

                print(loss_D.item(), loss_G.item())


    if __name__ == '__main__':
        main()

1.54.4.WGAN code example
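As a reading aid for the code in this section, the two losses it computes (`loss_D = lossr + lossf + gp` and `loss_G`) can be written out as follows. Here x_r is a real batch, x_f = G(z) is a generated batch, \hat{x} is the random interpolation built inside `gradient_penalty`, and \lambda corresponds to `LAMBDA = 0.3`. The symbols are shorthand introduced here for illustration; the code itself does not define them:

```latex
L_D = -\,\mathbb{E}_{x_r}\left[D(x_r)\right]
      + \mathbb{E}_{x_f}\left[D(x_f)\right]
      + \lambda\, \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right],
\qquad
\hat{x} = \alpha x_r + (1 - \alpha)\, x_f, \quad \alpha \sim U[0, 1]

L_G = -\,\mathbb{E}_{z}\left[D(G(z))\right]
```

The first two terms of L_D push D's score up on real samples and down on generated ones, and the last term keeps the norm of D's gradient at the interpolated points close to 1.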
    import torch
    from torch import nn, optim, autograd
    import numpy as np
    import visdom
    from torch.nn import functional as F
    from matplotlib import pyplot as plt
    import random

    h_dim = 400
    batchsz = 512
    viz = visdom.Visdom()  # needs a running visdom server


    class Generator(nn.Module):

        def __init__(self):
            super(Generator, self).__init__()

            # maps 2-D noise z to a 2-D sample
            self.net = nn.Sequential(
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, 2),
            )

        def forward(self, z):
            output = self.net(z)
            return output


    class Discriminator(nn.Module):

        def __init__(self):
            super(Discriminator, self).__init__()

            # maps a 2-D sample to a single score in (0, 1)
            self.net = nn.Sequential(
                nn.Linear(2, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, h_dim),
                nn.ReLU(True),
                nn.Linear(h_dim, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            output = self.net(x)
            return output.view(-1)


    def data_generator():
        # endless stream of batches drawn from 8 Gaussians placed on a circle
        scale = 2.
        centers = [
            (1, 0),
            (-1, 0),
            (0, 1),
            (0, -1),
            (1. / np.sqrt(2), 1. / np.sqrt(2)),
            (1. / np.sqrt(2), -1. / np.sqrt(2)),
            (-1. / np.sqrt(2), 1. / np.sqrt(2)),
            (-1. / np.sqrt(2), -1. / np.sqrt(2))
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        while True:
            dataset = []
            for i in range(batchsz):
                point = np.random.randn(2) * .02
                center = random.choice(centers)
                point[0] += center[0]
                point[1] += center[1]
                dataset.append(point)
            dataset = np.array(dataset, dtype='float32')
            dataset /= 1.414  # stdev
            yield dataset

        # for i in range(100000//25):
        #     for x in range(-2, 3):
        #         for y in range(-2, 3):
        #             point = np.random.randn(2).astype(np.float32) * 0.05
        #             point[0] += 2 * x
        #             point[1] += 2 * y
        #             dataset.append(point)
        #
        # dataset = np.array(dataset)
        # print('dataset:', dataset.shape)
        # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
        #
        # while True:
        #     np.random.shuffle(dataset)
        #
        #     for i in range(len(dataset)//batchsz):
        #         yield dataset[i*batchsz : (i+1)*batchsz]


    def generate_image(D, G, xr, epoch):
        """
        Generates and saves a plot of the true distribution, the generator, and the
        critic.
        """
        N_POINTS = 128
        RANGE = 3
        plt.clf()

        points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
        points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
        points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
        points = points.reshape((-1, 2))
        # (16384, 2)
        # print('p:', points.shape)

        # draw contour
        with torch.no_grad():
            points = torch.Tensor(points).cuda()  # [16384, 2]
            disc_map = D(points).cpu().numpy()  # [16384]
        x = y = np.linspace(-RANGE, RANGE, N_POINTS)
        cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
        plt.clabel(cs, inline=1, fontsize=10)
        # plt.colorbar()

        # draw samples
        with torch.no_grad():
            z = torch.randn(batchsz, 2).cuda()  # [b, 2]
            samples = G(z).cpu().numpy()  # [b, 2]
        plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
        plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

        viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


    def weights_init(m):
        if isinstance(m, nn.Linear):
            # m.weight.data.normal_(0.0, 0.02)
            nn.init.kaiming_normal_(m.weight)
            m.bias.data.fill_(0)


    def gradient_penalty(D, xr, xf):
        """
        :param D: discriminator
        :param xr: real batch, [b, 2]
        :param xf: generated batch, [b, 2]
        :return: gradient penalty (scalar tensor)
        """
        LAMBDA = 0.3

        # only constrain the Discriminator
        xf = xf.detach()
        xr = xr.detach()

        # [b, 1] => [b, 2]
        alpha = torch.rand(batchsz, 1).cuda()
        alpha = alpha.expand_as(xr)

        # random points on the segments between real and generated samples
        interpolates = alpha * xr + ((1 - alpha) * xf)
        interpolates.requires_grad_()

        disc_interpolates = D(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        # penalize gradient norms that differ from 1
        gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

        return gp


    def main():
        torch.manual_seed(23)
        np.random.seed(23)

        G = Generator().cuda()
        D = Discriminator().cuda()
        G.apply(weights_init)
        D.apply(weights_init)

        optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
        optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

        data_iter = data_generator()
        print('batch:', next(data_iter).shape)

        viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',
                                                      legend=['D', 'G']))

        for epoch in range(50000):

            # 1. train discriminator for k steps
            for _ in range(5):
                x = next(data_iter)
                xr = torch.from_numpy(x).cuda()

                # [b]
                predr = (D(xr))
                # maximize D(xr) by minimizing -D(xr)
                lossr = - (predr.mean())

                # [b, 2]
                z = torch.randn(batchsz, 2).cuda()
                # stop gradient on G
                # [b, 2]
                xf = G(z).detach()
                # [b]
                predf = (D(xf))
                # min predf
                lossf = (predf.mean())

                # gradient penalty
                gp = gradient_penalty(D, xr, xf)

                loss_D = lossr + lossf + gp
                optim_D.zero_grad()
                loss_D.backward()
                # for p in D.parameters():
                #     print(p.grad.norm())
                optim_D.step()

            # 2. train Generator
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z)
            predf = (D(xf))
            # max predf
            loss_G = - (predf.mean())
            optim_G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if epoch % 100 == 0:
                viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

                # matplotlib cannot read CUDA tensors, so copy the real batch to host memory
                generate_image(D, G, xr.cpu().numpy(), epoch)

                print(loss_D.item(), loss_G.item())


    if __name__ == '__main__':
        main()

1.54.5.References
https://zhuanlan.zhihu.com/p/117529144
https://blog.csdn.net/jizhidexiaoming/article/details/96485095