CycleGAN Source Code Walkthrough (Part 2): Training
The training code lives in train.py. First the networks are defined: two generators (A2B and B2A) and two discriminators (D_A and D_B), together with their optimizers. Each optimizer holds only the parameters of its own network(s), which guarantees that a generator step and a discriminator step never update each other's weights.
###### Definition of variables ######
# Networks
netG_A2B = Generator(opt.input_nc, opt.output_nc)
netG_B2A = Generator(opt.output_nc, opt.input_nc)
netD_A = Discriminator(opt.input_nc)
netD_B = Discriminator(opt.output_nc)

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                               lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
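The comment above mentions "LR schedulers", but the snippet stops at the optimizers. CycleGAN training typically keeps the learning rate constant for the first half of training and then decays it linearly to zero. A sketch of such a decay rule (class and parameter names, and the epoch counts, are assumptions for illustration, not taken from the source):

```python
# Sketch of a linear learning-rate decay rule: the multiplicative factor
# stays at 1.0 until decay_start_epoch, then falls linearly to 0.0 at
# n_epochs. Names and values are assumptions, not from train.py.
class LinearDecay:
    def __init__(self, n_epochs, decay_start_epoch):
        assert n_epochs > decay_start_epoch
        self.n_epochs = n_epochs
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # factor applied to the base learning rate at the given epoch
        return 1.0 - max(0, epoch - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

decay = LinearDecay(n_epochs=200, decay_start_epoch=100).step
# could then be plugged into torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=decay)
```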
Next, the data:
# Dataset loader
transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
                transforms.RandomCrop(opt.size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
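The resize-then-crop pair is a jitter augmentation: the image is upscaled by a factor of 1.12 and a random opt.size patch is cropped back out, and the final Normalize maps ToTensor's [0, 1] range to [-1, 1], matching a tanh-output generator. A quick check of those numbers, assuming the usual opt.size of 256 (an assumption, not stated in the snippet):

```python
# Assuming opt.size = 256 (the common CycleGAN setting; an assumption here)
size = 256
resized = int(size * 1.12)   # upscale target before the random 256x256 crop
# Normalize(mean=0.5, std=0.5) maps ToTensor's [0, 1] output to [-1, 1]:
lo = (0.0 - 0.5) / 0.5
hi = (1.0 - 0.5) / 0.5
print(resized, lo, hi)       # 286 -1.0 1.0
```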
With everything in place, we can compute the losses, backpropagate, and update the networks. In each iteration the generators are updated first, then the two discriminators in turn.
Generators: loss = identity loss + adversarial (GAN) loss + cycle-consistency loss
###### Generators A2B and B2A ######
optimizer_G.zero_grad()

# Identity loss
# G_A2B(B) should equal B if real B is fed
same_B = netG_A2B(real_B)
loss_identity_B = criterion_identity(same_B, real_B)*5.0
# G_B2A(A) should equal A if real A is fed
same_A = netG_B2A(real_A)
loss_identity_A = criterion_identity(same_A, real_A)*5.0

# GAN loss
fake_B = netG_A2B(real_A)
pred_fake = netD_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake, target_real)
fake_A = netG_B2A(real_B)
pred_fake = netD_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

# Cycle loss
recovered_A = netG_B2A(fake_B)
loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0
recovered_B = netG_A2B(fake_A)
loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

# Total loss
loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
loss_G.backward()
optimizer_G.step()
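The snippet above uses criterion_GAN, criterion_cycle, criterion_identity, target_real, and target_fake without showing how they are built. As a hedged sketch (this code base commonly uses an LSGAN-style MSE loss for the adversarial term and L1 for cycle and identity, but treat the exact choices here as assumptions to check against the actual train.py), the targets are just constant tensors shaped like the discriminator output:

```python
import torch

# Assumed loss setup (LSGAN-style); verify against the actual train.py.
criterion_GAN = torch.nn.MSELoss()         # adversarial loss
criterion_cycle = torch.nn.L1Loss()        # cycle-consistency loss
criterion_identity = torch.nn.L1Loss()     # identity loss

pred_fake = torch.full((4,), 0.3)          # stand-in for a discriminator output
target_real = torch.ones_like(pred_fake)   # label 1.0 for "real"
target_fake = torch.zeros_like(pred_fake)  # label 0.0 for "fake"

loss = criterion_GAN(pred_fake, target_real)   # mean of (0.3 - 1.0)^2 = 0.49
```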
Discriminator A: loss = classification loss on real samples + classification loss on fake samples
###### Discriminator A ######
optimizer_D_A.zero_grad()

# Real loss
pred_real = netD_A(real_A)
loss_D_real = criterion_GAN(pred_real, target_real)

# Fake loss
fake_A = fake_A_buffer.push_and_pop(fake_A)
pred_fake = netD_A(fake_A.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)

# Total loss
loss_D_A = (loss_D_real + loss_D_fake)*0.5
loss_D_A.backward()
optimizer_D_A.step()
###################################
Discriminator B: loss = classification loss on real samples + classification loss on fake samples
###### Discriminator B ######
optimizer_D_B.zero_grad()

# Real loss
pred_real = netD_B(real_B)
loss_D_real = criterion_GAN(pred_real, target_real)

# Fake loss
fake_B = fake_B_buffer.push_and_pop(fake_B)
pred_fake = netD_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, target_fake)

# Total loss
loss_D_B = (loss_D_real + loss_D_fake)*0.5
loss_D_B.backward()
optimizer_D_B.step()
###################################
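The fake_A_buffer and fake_B_buffer used above are image history buffers, the replay trick (from Shrivastava et al.) that CycleGAN adopts: push_and_pop stores recent fakes and sometimes returns an older one in place of the newest, so the discriminator does not overfit to the very latest generator output. A sketch of such a buffer (the real class lives elsewhere in the repo, so the details here are assumptions):

```python
import random
import torch

# Hypothetical sketch of the image history buffer: keep up to max_size
# past fakes; once full, answer with an old fake half the time and swap
# the new one in, otherwise return the new fake unchanged.
class ReplayBuffer:
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for img in batch:
            img = img.unsqueeze(0)                  # keep a batch dim of 1
            if len(self.data) < self.max_size:
                self.data.append(img)               # buffer not full: store and return
                out.append(img)
            elif random.random() > 0.5:
                i = random.randint(0, self.max_size - 1)
                out.append(self.data[i].clone())    # return an old fake...
                self.data[i] = img                  # ...and swap the new one in
            else:
                out.append(img)                     # return the new fake as-is
        return torch.cat(out)
```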
Note that in both discriminator losses the fake samples fake_A and fake_B go through detach(), which cuts them out of the generator's computation graph. Backpropagating the discriminator loss therefore computes no gradients for the generator, avoiding unnecessary computation (and unintended generator updates).
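A two-line toy graph makes the effect of detach() concrete; here g_weight and d_weight stand in for generator and discriminator parameters:

```python
import torch

# Minimal demonstration of why fakes are detach()ed in the discriminator
# step: gradients stop at the detach boundary, so the "generator"
# parameter receives none.
g_weight = torch.tensor(2.0, requires_grad=True)   # stands in for netG
d_weight = torch.tensor(3.0, requires_grad=True)   # stands in for netD

fake = g_weight * 1.5                 # "generated" sample: value 3.0
pred = d_weight * fake.detach()       # discriminator sees a detached fake
pred.backward()

print(g_weight.grad)   # None: no gradient flowed back into the generator
print(d_weight.grad)   # tensor(3.): d(pred)/d(d_weight) = detached fake = 3.0
```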