CGAN for generating cifar10, cifar100, mnist, fashion_mnist, STL10 and Anime images (PyTorch)
Full code: https://www.lanzouw.com/iVadvo386of
CGAN goes one step further than DCGAN: by conditioning on label information, it can generate samples of a specified class.
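The conditioning mechanism itself is simple: the class label is one-hot encoded, concatenated to the noise vector before it enters the generator, and tiled over the spatial dimensions and concatenated to the image before it enters the discriminator. The sketch below illustrates just that step; it is not part of the original script, uses `torch.nn.functional.one_hot` instead of the sklearn `LabelBinarizer` used in the full code, and all sizes are illustrative.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only
N, nz, n_classes, nc, H, W = 8, 100, 10, 3, 128, 128

labels = torch.randint(0, n_classes, (N,))            # integer class labels
one_hot = F.one_hot(labels, n_classes).float()        # (N, n_classes)

# Generator input: concatenate the label to the noise along the channel dimension
noise = torch.randn(N, nz, 1, 1)
g_in = torch.cat((noise, one_hot.view(N, n_classes, 1, 1)), dim=1)   # (N, nz + n_classes, 1, 1)

# Discriminator input: tile the label over the spatial dimensions and
# concatenate it to the (real or fake) image
images = torch.randn(N, nc, H, W)                     # stand-in for a batch of images
label_maps = one_hot.view(N, n_classes, 1, 1).expand(N, n_classes, H, W)
d_in = torch.cat((images, label_maps), dim=1)         # (N, nc + n_classes, H, W)

print(g_in.shape, d_in.shape)
```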
For the DCGAN version of this code, see: DCGAN生成cifar10, cifar100, mnist, fashion_mnist, STL10, Anime图片(pytorch) (stay_zezo's blog on CSDN).
The complete CGAN code follows; for the project directory layout, compare with the DCGAN post above. A short helper for creating the output folders comes first.
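One practical note: the script below writes generated image grids to `./generated_fake/<dataset>/` and checkpoints to `./checkpoint/`, but it does not create those folders itself. A small optional helper (not in the original code) to create them up front:

```python
import os

# Create the folders the training script writes to; 'mnist' matches the
# `datasets` variable used in the script below.
datasets = 'mnist'
os.makedirs('./generated_fake/%s' % datasets, exist_ok=True)
os.makedirs('./checkpoint', exist_ok=True)
```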
```python
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import random

# Set the random seeds for python, numpy and pytorch
def seed_torch(seed=2021):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch()

# resume: whether to continue training from a saved checkpoint
resume = True        # True: resume training, False: train from scratch
datasets = 'mnist'   # choose from cifar10, cifar100, mnist, fashion_mnist, STL10, Anime

if datasets == 'cifar10' or datasets == 'cifar100' or datasets == 'STL10' or datasets == 'Anime':
    nc = 3  # number of image channels
elif datasets == 'mnist' or datasets == 'fashion_mnist':
    nc = 1
else:
    print('Invalid dataset selection')

# number of classes (10 for the datasets used here; cifar100 would need 100)
n_classes = 10
# label the generator is asked to produce when saving sample images
target_label = 4
# batch size
batch_size = 128
# dimension of the noise vector
nz = 100
# base channel width of the discriminator
ndf = 64
# base channel width of the generator
ngf = 64
# labels for real and fake samples
real_label = 1.0
fake_label = 0.0
start_epoch = 0

# Models
# Generator: input (N, nz + n_classes, 1, 1)
netG = nn.Sequential(
    nn.ConvTranspose2d(nz + n_classes, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 4, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 2, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False), nn.Tanh()
    # output (N, nc, 128, 128)
)

# Discriminator: input (N, nc + n_classes, 128, 128)
netD = nn.Sequential(
    nn.Conv2d(nc + n_classes, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),  # (N, 1, 1, 1)
    nn.Flatten(),                                # (N, 1)
    nn.Sigmoid()
)

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

netD.apply(weights_init)
netG.apply(weights_init)

# Dataset transforms
apply_transform1 = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
apply_transform2 = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the chosen dataset
if datasets == 'cifar100':
    train_dataset = torchvision.datasets.CIFAR100(root='../data/cifar100', train=False, download=True, transform=apply_transform1)
elif datasets == 'cifar10':
    train_dataset = torchvision.datasets.CIFAR10(root='../data/cifar10', train=False, download=True, transform=apply_transform1)
elif datasets == 'STL10':
    train_dataset = torchvision.datasets.STL10(root='../data/STL10', split='train', download=True, transform=apply_transform1)
elif datasets == 'mnist':
    train_dataset = torchvision.datasets.MNIST(root='../data/mnist', train=False, download=True, transform=apply_transform2)
elif datasets == 'fashion_mnist':
    train_dataset = torchvision.datasets.FashionMNIST(root='../data/fashion_mnist', train=False, download=True, transform=apply_transform2)
elif datasets == 'Anime':
    train_dataset = torchvision.datasets.ImageFolder(root='../data/Anime', transform=apply_transform1)
else:
    print('Dataset does not exist')

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Loss function and device
criterion = torch.nn.BCELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# setup optimizers
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Show 16 sample images from the dataset
if datasets == 'Anime':
    image, label = next(iter(train_loader))
    image = (image * 0.5 + 0.5)[:16]
elif datasets == 'mnist' or datasets == 'fashion_mnist':
    image = next(iter(train_loader))[0]
    image = image[:16] * 0.5 + 0.5
elif datasets == 'STL10':
    image = torch.Tensor(train_dataset.data[:16] / 255)
else:
    image = torch.Tensor(train_dataset.data[:16] / 255).permute(0, 3, 1, 2)
plt.imshow(torchvision.utils.make_grid(image, nrow=4).permute(1, 2, 0))

lb = LabelBinarizer()
lb.fit(list(range(0, n_classes)))

# One-hot encode the labels
def to_categrical(y: torch.FloatTensor):
    y_one_hot = lb.transform(y.cpu())
    floatTensor = torch.FloatTensor(y_one_hot)
    return floatTensor.to(device)

# Concatenate samples with their one-hot labels; this is the conditioning input
def concanate_data_label(data, y):   # data: (N, nc, 128, 128)
    y_one_hot = to_categrical(y)     # (N, 1) -> (N, n_classes)
    con = torch.cat((data, y_one_hot), 1)
    return con

# If resuming, load the pretrained checkpoint
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/GAN_%s_best.pth' % datasets)
    netG.load_state_dict(checkpoint['net_G'])
    netD.load_state_dict(checkpoint['net_D'])
    start_epoch = checkpoint['start_epoch']

print('netG:', '\n', netG)
print('netD:', '\n', netD)
print('training on: ', device, ' start_epoch', start_epoch)

netD, netG = netD.to(device), netG.to(device)

# Alternately train the discriminator (generator fixed) and the generator
for epoch in range(start_epoch, 500):
    for batch, (data, target) in enumerate(train_loader):
        # if epoch % 2 == 0 and batch == 0:
        #     torchvision.utils.save_image(data[:16], './generated_fake/%s/real_epoch_%d_grid.png' % (datasets, epoch), nrow=4, normalize=True)
        data = data.to(device)
        target = target.to(device)

        # Concatenate the real data with its label
        target1 = to_categrical(target).unsqueeze(2).unsqueeze(3).float()  # appended to the noise
        target2 = target1.repeat(1, 1, data.size(2), data.size(3))         # appended to the data
        # (N, nc, 128, 128) + (N, n_classes, 128, 128) -> (N, nc + n_classes, 128, 128)
        data = torch.cat((data, target2), dim=1)
        label = torch.full((data.size(0), 1), real_label).to(device)

        # (1) Train the discriminator
        # on real data
        netD.zero_grad()
        output = netD(data)
        loss_D1 = criterion(output, label)
        loss_D1.backward()

        # on fake data: concatenate the noise with the label
        noise_z = torch.randn(data.size(0), nz, 1, 1).to(device)
        noise_z = torch.cat((noise_z, target1), dim=1)        # (N, nz + n_classes, 1, 1)
        # concatenate the fake data with the label
        fake_data = netG(noise_z)
        fake_data = torch.cat((fake_data, target2), dim=1)    # (N, nc + n_classes, 128, 128)
        label = torch.full((data.size(0), 1), fake_label).to(device)
        output = netD(fake_data.detach())
        loss_D2 = criterion(output, label)
        loss_D2.backward()
        # update the discriminator
        optimizerD.step()

        # (2) Train the generator
        netG.zero_grad()
        label = torch.full((data.size(0), 1), real_label).to(device)
        output = netD(fake_data.to(device))
        lossG = criterion(output, label)
        lossG.backward()
        # update the generator
        optimizerG.step()

        if batch % 10 == 0:
            print('epoch: %4d, batch: %4d, discriminator loss: %.4f, generator loss: %.4f'
                  % (epoch, batch, loss_D1.item() + loss_D2.item(), lossG.item()))

        # Save images every 2 epochs
        if epoch % 2 == 0 and batch == 0:
            # Generate images of the chosen target_label
            noise_z1 = torch.randn(data.size(0), nz, 1, 1).to(device)
            target3 = to_categrical(torch.full((data.size(0), 1), target_label)).unsqueeze(2).unsqueeze(3).float()
            noise_z = torch.cat((noise_z1, target3), dim=1)   # (N, nz + n_classes, 1, 1)
            fake_data = netG(noise_z.to(device))

            # For single-channel images, replicate to three channels before saving
            if nc == 1:
                fake_data = torch.cat((fake_data, fake_data, fake_data), dim=1)   # (N, 1, H, W) -> (N, 3, H, W)

            # Save a single image, de-normalized back to [0, 1]
            data = fake_data.detach().cpu().permute(0, 2, 3, 1)
            data = np.array(data)
            data = data * 0.5 + 0.5
            plt.imsave('./generated_fake/%s/epoch_%d.png' % (datasets, epoch), data[0])
            # Save a 4x4 grid
            torchvision.utils.save_image(fake_data[:16] * 0.5 + 0.5,
                                         './generated_fake/%s/epoch_%d_grid.png' % (datasets, epoch),
                                         nrow=4, normalize=True)

            # Save the model
            state = {'net_G': netG.state_dict(), 'net_D': netD.state_dict(), 'start_epoch': epoch + 1}
            torch.save(state, './checkpoint/GAN_%s_best.pth' % datasets)
            torch.save(state, './checkpoint/GAN_%s_best_copy.pth' % datasets)
```

Experimental results:
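To produce sample grids from a trained model outside of the training loop, here is a minimal inference sketch. It is not part of the original script: it assumes the checkpoint format saved above (generator weights under the 'net_G' key in ./checkpoint/GAN_<dataset>_best.pth), rebuilds the same generator, and uses the single-channel mnist setting (nc = 1); adjust nc and the path for other datasets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# Hyperparameters matching the training script above (mnist setting assumed)
nz, n_classes, ngf, nc = 100, 10, 64, 1
datasets, target_label = 'mnist', 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Same generator architecture as in the training script
netG = nn.Sequential(
    nn.ConvTranspose2d(nz + n_classes, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 4, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 2, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False), nn.Tanh(),
).to(device)

# Load the generator weights saved by the training loop
checkpoint = torch.load('./checkpoint/GAN_%s_best.pth' % datasets, map_location=device)
netG.load_state_dict(checkpoint['net_G'])
netG.eval()

with torch.no_grad():
    # 16 noise vectors, all conditioned on the same target label
    noise = torch.randn(16, nz, 1, 1, device=device)
    labels = torch.full((16,), target_label, dtype=torch.long, device=device)
    one_hot = F.one_hot(labels, n_classes).float().view(16, n_classes, 1, 1)
    fake = netG(torch.cat((noise, one_hot), dim=1))     # (16, nc, 128, 128), values in [-1, 1]
    torchvision.utils.save_image(fake * 0.5 + 0.5, 'sample_label_%d.png' % target_label, nrow=4)
```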