Generating cifar10, cifar100, mnist, fashion_mnist, STL10, and Anime images with a DCGAN (PyTorch)
Code download: https://www.lanzouw.com/ipl8Yo37qxi
Please download the Anime data from the Anime Face Dataset | Kaggle page; all of the other datasets ship with torchvision and are downloaded automatically.
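The Anime images are loaded with torchvision's ImageFolder (see the script below), which requires the pictures to sit inside at least one class subfolder under the root. The original post does not name that subfolder, so the layout below is only an assumption; a quick sanity check:

```python
# Hypothetical layout check for the Anime data, e.g. ../data/Anime/images/*.jpg
# (the class-subfolder name "images" is an assumption, not from the original post).
import torchvision
import torchvision.transforms as transforms

anime = torchvision.datasets.ImageFolder(root='../data/Anime',
                                         transform=transforms.ToTensor())
print(len(anime), anime.classes)  # image count and the class folders found
```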
The code below uses a DCGAN to generate cifar10, cifar100, mnist, fashion_mnist, STL10, or Anime images; the dataset is chosen with the datasets variable in the script.
Directory layout (shown as screenshots in the original post): the DCGAN3 project directory and the generated_fake output directory.
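Since the screenshots are not reproduced here, a minimal sketch of the folders the script expects, with paths taken from the training code below (checkpoints go to ./checkpoint/, sample images to ./generated_fake/<dataset>/):

```python
import os

# Per-dataset folders for generated sample images, plus one folder for model weights.
for name in ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'STL10', 'Anime']:
    os.makedirs('./generated_fake/%s' % name, exist_ok=True)
os.makedirs('./checkpoint', exist_ok=True)
```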
Some of the bundled checkpoints have already been trained and some have not; if you get an error that the model file does not exist, set resume = False.
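One way to avoid that error is to resume only when a checkpoint for the chosen dataset is actually present. A small sketch, with the checkpoint path copied from the training script:

```python
import os

datasets = 'Anime'  # same value as in the training script below
checkpoint_path = './checkpoint/GAN_%s_best.pth' % datasets
resume = os.path.exists(checkpoint_path)  # otherwise fall back to training from scratch
```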
```python
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# resume: whether to continue training from a previously saved checkpoint
resume = True       # True: resume training, False: train from scratch
datasets = 'Anime'  # choose from cifar10, cifar100, mnist, fashion_mnist, STL10, Anime

if datasets == 'cifar10' or datasets == 'cifar100' or datasets == 'STL10' or datasets == 'Anime':
    nc = 3  # number of image channels
elif datasets == 'mnist' or datasets == 'fashion_mnist':
    nc = 1
else:
    print('invalid dataset choice')

batch_size = 128
nz = 100   # dimension of the noise vector
ndf = 64
ngf = 64
real_label = 1.  # float labels, as required by BCELoss
fake_label = 0.
start_epoch = 0

# Define the models

# Generator: input (N, nz, 1, 1)
netG = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(ngf*8),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf*4),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf*4, ngf*4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf*4),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf*2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf*2, ngf*2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ngf*2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.ConvTranspose2d(ngf*2, nc, 4, 2, 1, bias=False),
    nn.Tanh()  # output (N, nc, 128, 128)
)

# Discriminator: input (N, nc, 128, 128)
netD = nn.Sequential(
    nn.Conv2d(nc, ndf*2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf*2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf*2, ndf*2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf*2),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf*4),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf*4, ndf*4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf*4),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
    nn.BatchNorm2d(ndf*8),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),  # (N, 1, 1, 1)
    nn.Flatten(),                              # (N, 1)
    nn.Sigmoid()
)

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

netD.apply(weights_init)
netG.apply(weights_init)

# Load the dataset
apply_transform1 = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

apply_transform2 = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

if datasets == 'cifar100':
    train_dataset = torchvision.datasets.CIFAR100(root='../data/cifar100', train=False,
                                                  download=True, transform=apply_transform1)
elif datasets == 'cifar10':
    train_dataset = torchvision.datasets.CIFAR10(root='../data/cifar10', train=False,
                                                 download=True, transform=apply_transform1)
elif datasets == 'STL10':
    train_dataset = torchvision.datasets.STL10(root='../data/STL10', split='train',
                                               download=True, transform=apply_transform1)
elif datasets == 'mnist':
    train_dataset = torchvision.datasets.MNIST(root='../data/mnist', train=False,
                                               download=True, transform=apply_transform2)
elif datasets == 'fashion_mnist':
    train_dataset = torchvision.datasets.FashionMNIST(root='../data/fashion_mnist', train=False,
                                                      download=True, transform=apply_transform2)
elif datasets == 'Anime':
    train_dataset = torchvision.datasets.ImageFolder(root='../data/Anime', transform=apply_transform1)
else:
    print('dataset does not exist')

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=4)

# Loss function and device
criterion = torch.nn.BCELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# setup optimizer
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Show 16 real images from the chosen dataset
if datasets == 'Anime':
    image, label = next(iter(train_loader))
    image = (image*0.5 + 0.5)[:16]
elif datasets == 'mnist' or datasets == 'fashion_mnist':
    image = next(iter(train_loader))[0]
    image = image[:16]*0.5 + 0.5
elif datasets == 'STL10':
    image = torch.Tensor(train_dataset.data[:16]/255)
else:
    image = torch.Tensor(train_dataset.data[:16]/255).permute(0, 3, 1, 2)
plt.imshow(torchvision.utils.make_grid(image, nrow=4).permute(1, 2, 0))

# Train and save the models
# If resuming, load the pretrained checkpoint
if resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/GAN_%s_best.pth' % datasets)
    netG.load_state_dict(checkpoint['net_G'])
    netD.load_state_dict(checkpoint['net_D'])
    start_epoch = checkpoint['start_epoch']

print('netG:', '\n', netG)
print('netD:', '\n', netD)
print('training on: ', device, ' start_epoch', start_epoch)

netD, netG = netD.to(device), netG.to(device)

for epoch in range(start_epoch, 300):
    for batch, (data, target) in enumerate(train_loader):
        batch_size = data.size(0)
        label = torch.full((batch_size, 1), real_label).to(device)
        # (1) Train the discriminator with the generator fixed
        # real data
        netD.zero_grad()
        data = data.to(device)
        output = netD(data)
        loss_D1 = criterion(output, label)
        loss_D1.backward()
        # fake data
        noise_z = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_data = netG(noise_z)
        label = torch.full((batch_size, 1), fake_label).to(device)
        output = netD(fake_data.detach())
        loss_D2 = criterion(output, label)
        loss_D2.backward()
        # update the discriminator
        optimizerD.step()
        # (2) Train the generator
        netG.zero_grad()
        label = torch.full((batch_size, 1), real_label).to(device)
        output = netD(fake_data)
        lossG = criterion(output, label)
        lossG.backward()
        # update the generator
        optimizerG.step()
        if batch % 100 == 0:
            print('epoch: %4d, batch: %4d, discriminator loss: %.4f, generator loss: %.4f'
                  % (epoch, batch, loss_D1.item() + loss_D2.item(), lossG.item()))
    # Save sample images every 2 epochs
    if epoch % 2 == 0:
        # single-channel images are repeated to 3 channels before saving
        if nc == 1:
            fake_data = torch.cat((fake_data, fake_data, fake_data), dim=1)  # (N,1,H,W) -> (N,3,H,W)
        data = fake_data.detach().cpu().permute(0, 2, 3, 1)
        data = np.array(data)
        # save a single image, rescaled from [-1, 1] to [0, 1]
        data = data*0.5 + 0.5
        plt.imsave('./generated_fake/%s/epoch_%d.png' % (datasets, epoch), data[0])
        torchvision.utils.save_image(fake_data[:16],
                                     fp='./generated_fake/%s/epoch%d_grid.png' % (datasets, epoch),
                                     nrow=4, normalize=True)
    # Save the models
    state = {'net_G': netG.state_dict(), 'net_D': netD.state_dict(), 'start_epoch': epoch + 1}
    torch.save(state, './checkpoint/GAN_%s_best.pth' % datasets)
    torch.save(state, './checkpoint/GAN_%s_best_copy.pth' % datasets)
```

Experimental results: (the generated sample images are shown in the original post.)
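To generate new images from a saved checkpoint without retraining, a minimal sampling sketch; it assumes netG is built exactly as above and that a checkpoint for the chosen dataset (Anime here) already exists:

```python
import torch
import torchvision

# Load the generator weights saved by the training loop.
checkpoint = torch.load('./checkpoint/GAN_Anime_best.pth', map_location='cpu')
netG.load_state_dict(checkpoint['net_G'])
netG.eval()

with torch.no_grad():
    noise = torch.randn(16, 100, 1, 1)  # nz = 100, as in the training script
    fake = netG(noise)                  # (16, nc, 128, 128), values in [-1, 1] from Tanh
torchvision.utils.save_image(fake, 'samples.png', nrow=4, normalize=True)
```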
Summary
The same DCGAN generator and discriminator can be trained on cifar10, cifar100, mnist, fashion_mnist, STL10, or the Anime face data simply by changing the datasets variable; the script periodically saves sample images and checkpoints so that training can be resumed later.