【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络
「@Author:Runsen」
GAN 是使用兩個神經網絡模型訓練的生成模型。一種模型稱為生成網絡模型,它學習生成新的似是而非的樣本。另一個模型被稱為判別網絡,它學習區分生成的例子和真實的例子。
生成性對抗網絡
2014,蒙特利爾大學的Ian Goodfellow和他的朋友發明了生成性對抗網絡(GAN)。自它出版以來,有許多它的變體和客觀功能來解決它的問題
論文在這里找到.
論文提出了兩種模型:生成模型和判別模型。兩個模型競爭,以產生真實和假的樣本。2016年,Yann LeCun將GANs描述為“過去二十年機器學習中最酷的想法”。
GAN 的大部分研究和應用都集中在計算機視覺領域。
其原因是卷積神經網絡 (CNN) 等深度學習模型在過去 5 到 7 年中在計算機視覺領域取得了巨大成功,例如在具有挑戰性的任務(如對象檢測和人臉識別。
GAN 的典型例子是生成新的逼真的照片,最令人吃驚的是生成照片般逼真的人臉的例子。
在本教程中,我們將實現一個簡單的GAN生成假的MNIST樣本。
import?torch import?torch.nn?as?nn import?torch.optim?as?optim from?torch.utils.data?import?DataLoaderimport?torchvision import?torchvision.datasets?as?datasets import?torchvision.transforms?as?transforms import?torchvision.utils?as?utilsimport?numpy?as?np import?matplotlib.pyplot?as?plt #?CPU?/?GPU?Setting device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu') print(device)??#cuda使用MNIST數據集,具有最小大小的數據集。
它由60000個訓練圖像和10000個測試圖像組成,每個圖像有28*28的大小和一個彩色通道。
#?Define?a?transform? transform?=?transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean?=?(0.5,?),?std?=?(0.5,?)) ])# batch_size是一個前向和后向傳播過程中的圖像數。 batch_size?=?100mnist?=?datasets.MNIST('./data/MNIST',?download?=?True,?train?=?True,?transform?=?transform)mnist_loader?=?DataLoader(dataset?=?mnist,?batch_size?=?batch_size,?shuffle?=?True) #?CPU def?imshow(img,?title):img?=?utils.make_grid(img.cpu().detach())img?=?(img+1)/2npimg?=?img.detach().numpy()plt.imshow(np.transpose(npimg,?(1,?2,?0)))plt.title(title)plt.show() #GPU def?imshow(img,?title):npimg?=?img.detach().numpy()fig?=?plt.figure(figsize?=?(10,?10))plt.imshow(np.transpose(npimg,?(1,?2,?0)))plt.title(title)plt.show()images,?labels?=?iter(mnist_loader).next() imshow(images[0:16,?:,?:],?"MNIST?Images")建立一個GANs模型。一個Generator和Discriminator
GANs由完全連接的層組成。它將從100維高斯分布采樣的噪聲轉換為MNIST圖像。鑒別器網絡也由完全連接的層組成,用于區分輸入數據是真是假。
class?Generator(nn.Module):def?__init__(self):super(Generator,?self).__init__()latent_size?=?100output?=?28*28self.main?=?nn.Sequential(nn.Linear(latent_size,?128),nn.ReLU(inplace=True),nn.Linear(128,?256),nn.ReLU(inplace=True),nn.Linear(256,?512),nn.ReLU(inplace=True),nn.Linear(512,?output),nn.Tanh())def?forward(self,?x):out?=?self.main(x)out?=?out.view(-1,?1,?28,?28)return?outclass?Discriminator(nn.Module):def?__init__(self):super(Discriminator,?self).__init__()n_features?=?28?*?28n_out?=?1self.main?=?nn.Sequential(nn.Linear(n_features,?512),nn.ReLU(inplace=True),nn.Linear(512,?256),nn.ReLU(inplace=True),nn.Linear(256,?128),nn.ReLU(inplace=True),nn.Linear(128,?64),nn.ReLU(inplace=True),nn.Linear(64,?n_out),nn.Sigmoid()????????)def?forward(self,?x):x?=?x.view(-1,?28*28)out?=?self.main(x)return?outG?=?Generator().to(device) D?=?Discriminator().to(device)生成性對抗網絡訓練過程的損失函數是二進制交叉熵損失,由torch.nn.BCELoss實現。
這兩種模型都使用torch.optim.Adam作為優化工具,學習率設置為0.002。
#?Objective?Function criterion?=?nn.BCELoss()#?Optimizer G_optimizer?=?optim.Adam(G.parameters(),?lr?=?0.0002) D_optimizer?=?optim.Adam(D.parameters(),?lr?=?0.0002)#?Constants noise_dim?=?100 num_epochs?=?50 total_batch?=?len(mnist_loader)#?Lists G_losses?=?[] D_losses?=?[]#?Noise sample_size?=?16 fixed_noise?=?torch.randn(sample_size,?noise_dim).to(device)#?Train for?epoch?in?range(num_epochs):for?i,?(images,?labels)?in?enumerate(mnist_loader):#?Images?#images?=?images.reshape(batch_size,?-1).float().to(device)#?Labels?#ones?=?torch.ones(batch_size,?1).to(device)zeros?=?torch.zeros(batch_size,?1).to(device)#?Noise?#noise?=?torch.randn(batch_size,?noise_dim).to(device)#?Initialize?OptimizersD_optimizer.zero_grad()G_optimizer.zero_grad()########################?Train?Discriminator?#########################?Forward?Images?#prob_real?=?D(images)D_real_loss?=?criterion(prob_real,?ones)#?Generate?Samples?#fake_images?=?G(noise)prob_fake?=?D(fake_images)#?Forward?Fake?Samples?and?Calculate?Discriminator?Loss?#D_fake_loss?=?criterion(prob_fake,?zeros)D_loss?=?(D_real_loss?+?D_fake_loss).mean()#?Back?Propagation?and?UpdateD_loss.backward()D_optimizer.step()####################?Train?Generator?####################fake_images?=?G(noise)prob_fake?=?D(fake_images)#?According?to?the?p?3?in?paper,#?early?in?learning,?when?G?is?very?poor,?D?can?reject?samples?from?G.#?In?this?case,?log(1-D(G(z)))?saturates.?#?thus,?train?G?to?maximiaze?log(D(G(z)))?instead?of?minimizing?log(1-D(G(z)))G_loss?=?criterion(prob_fake,?ones)#?Back?Propagation?and?UpdateG_loss.backward()G_optimizer.step()#?Save?Losses?for?Plotting?LaterG_losses.append(G_loss.item())D_losses.append(D_loss.item())#?Print?Statistics?#if?(i?+?1)?%?100?==?0:print("Epoch?[%d/%d]?Iter?[%d/%d],?D_Loss:?%.4f?G_Loss:?%.4f"%(epoch+1,?num_epochs,?i+1,?total_batch,?D_loss.item(),?G_loss.item()))#?Generate?Samples?#if?epoch?%?1?==?0:fake_samples?=?G(fixed_noise)imshow(fake_samples,?"Generated?MNIST?Images")#?Save?Model?Weights?for?Digit?Generation torch.save(G.state_dict(),?'./data/GAN.pkl') plt.figure(figsize?=?(8,?6)) plt.title("Generator?and?Discriminator?Loss?During?Training") plt.plot(G_losses,?label="Generator") plt.plot(D_losses,?label="Discriminator") plt.xlabel("Iterations") plt.ylabel("Losses") plt.legend() plt.show() sample_size?=?64 noise_dim?=?100noise?=?torch.randn(sample_size,?noise_dim).to(device)G.load_state_dict(torch.load('GAN.pkl')) fake_samples?=?G(fixed_noise) imshow(fake_samples,?"Generated?MNIST?Images")GAN生成性對抗網絡的運用
將語義圖像翻譯成城市景觀和建筑物的照片。
將衛星照片翻譯成地圖。
從白天到晚上的照片翻譯。
將黑白照片翻譯成彩色。
- 論文在這里找到:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
- 上述代碼的論文:https://arxiv.org/abs/1511.06434
- 上述代碼:https://github.com/yihui-he/GAN-MNIST
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》課件合集 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【Python】简约而不简单|值得收藏的
- 下一篇: XML解析-Dom4j的DOM解析方式更