用MXNet实现mnist的生成对抗网络(GAN)
用MXNet實現mnist的生成對抗網絡(GAN)
生成式對抗網絡(Generative Adversarial Network,簡稱GAN)由一個生成網絡與一個判別網絡組成。生成網絡從潛在空間(latent space)中隨機采樣作為輸入,其輸出結果需要盡量模仿訓練集中的真實樣本。判別網絡的輸入則為真實樣本或生成網絡的輸出,其目的是將生成網絡的輸出從真實樣本中盡可能分辨出來。而生成網絡則要盡可能地欺騙判別網絡。兩個網絡相互對抗、不斷調整參數,最終目的是使判別網絡無法判斷生成網絡的輸出結果是否真實。從數據的分布來看就是使得生成的數據分布\(P_z(z)\)與原來的數據\(P_{data}(x)\)十分接近,理想的情況下為\(P_z(z)=P_{data}(x)\)。本文給出了GAN的Loss函數、說明GAN的訓練原理,再結合最簡單的例子mnist,用MXNet來實現GAN。
GAN的基本概念
在一樣樣本中加入一些精心編制的噪聲,會使得原來的分類器失效。圖1是一個廣為流傳的示例,左邊的分類器得到的是熊貓而右邊被分類為了長臂猿。
圖1 誤分類的示例為什么會有這樣的結果?圖像分類器本質上是多維空間中的決策邊界,當訓練的樣本不足時,可能會使得分類器過擬合。當向原樣本中加入一些L2范數很小的噪聲時,人類的視覺是無法分別這些細微的差別,所以依然會認為和原樣本的分類沒什么區別。但對過擬合的分類器來說,輸入樣本的小偏差可能使得最后的決策點越過了原來的決策邊界,進入到其它分類中了。這就導致了錯誤的分類。
對于生成網絡設為G,\(G(Z)\)為生成的對抗樣本,理想條件下\(G(z)\)隨機生成的樣本分布與真實樣本分布是一樣。對于判別網絡設為D,\(D(x)\)為判別樣本是真實的概率,理想條件下對真實樣本有\(G(x)=1\),對生成樣本有\(D(G(z))=0\)。為了達到效果,設計了如圖2所示的網絡結構:
圖2 GAN的網絡結構Loss函數如下:
\[ V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.1} \]
這個Loss函數的優化方法與EM算法的思想是相似的:在G是固定的情況下,判別網絡D的精確率越高,那么V就越大;在D固定的條件下,生成網絡G的生成的樣本越像實際樣本,那么V就越小。所有V(G,D)進行了極小極大化博弈:
\[ \min_G \max_D V(G,D)=E_{x-p_{data}(x)}[\log(D(x))] + E_{z-p_{z}(z)}[1-\log(D(G(z)))] \tag{1.2} \]
實現mnist的GAN
MXNet的源碼給出了mnsit的GAN實現(見dcgan.py),但是沒有給出詳細的說明,我在這里詳細解釋下,源文件在裝了相關的python包之后是能正確運行的。DCGAN是指Deep Convolution Generative Adversarial Netword(深度卷積生成式對抗網格)。
mnist的網絡相對來說比較簡單,如圖所示:
圖3 D是判別式網絡,G是生成式網絡,可以看到兩個網絡輸出的數據大致成反向對稱生成網絡G的結構與判別網絡D的結果是反向對稱的(雖然兩個網絡的開頭或者結尾有所不同,但這是為了與結果相對應),這里有一個很重要但被很多文章忽略的假設:判別網絡從潛在空間(latent space)是可逆的。不是說從最后的結果是可逆的,但從原始圖片映射到潛在空間這個過程(比如說從全連接層的n(n一般比較大)維向量)是可逆的,這里說的可逆不是嚴格意義上的反函數,而是從視覺判別結果上區別不大,比如說在G與D理想的情況下數字9通過判別網絡得到一個100維的向量,再將這個100維向量通過生成網絡G得到一張圖片,這張圖片在人類看來也是9。
代碼實現如下:
def make_dcgan_sym(ngf, ndf, nc, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):BatchNorm = mx.sym.BatchNorm# 生成網絡G# 輸入生成網絡G的變量,這個是潛在空間rand = mx.sym.Variable('rand')g1 = mx.sym.Deconvolution(rand, name='g1', kernel=(4,4), num_filter=ngf*8, no_bias=no_bias)gbn1 = BatchNorm(g1, name='gbn1', fix_gamma=fix_gamma, eps=eps)gact1 = mx.sym.Activation(gbn1, name='gact1', act_type='relu')g2 = mx.sym.Deconvolution(gact1, name='g2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*4, no_bias=no_bias)gbn2 = BatchNorm(g2, name='gbn2', fix_gamma=fix_gamma, eps=eps)gact2 = mx.sym.Activation(gbn2, name='gact2', act_type='relu')g3 = mx.sym.Deconvolution(gact2, name='g3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf*2, no_bias=no_bias)gbn3 = BatchNorm(g3, name='gbn3', fix_gamma=fix_gamma, eps=eps)gact3 = mx.sym.Activation(gbn3, name='gact3', act_type='relu')g4 = mx.sym.Deconvolution(gact3, name='g4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ngf, no_bias=no_bias)gbn4 = BatchNorm(g4, name='gbn4', fix_gamma=fix_gamma, eps=eps)gact4 = mx.sym.Activation(gbn4, name='gact4', act_type='relu')g5 = mx.sym.Deconvolution(gact4, name='g5', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=nc, no_bias=no_bias)# 生成網絡G最后得到一張相片gout = mx.sym.Activation(g5, name='gact5', act_type='tanh')# 判別網絡D,這里里的結構與一般的分類網絡區別不大data = mx.sym.Variable('data')label = mx.sym.Variable('label')d1 = mx.sym.Convolution(data, name='d1', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf, no_bias=no_bias)dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)d2 = mx.sym.Convolution(dact1, name='d2', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*2, no_bias=no_bias)dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)d3 = mx.sym.Convolution(dact2, name='d3', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*4, no_bias=no_bias)dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)d4 = mx.sym.Convolution(dact3, name='d4', kernel=(4,4), stride=(2,2), pad=(1,1), num_filter=ndf*8, no_bias=no_bias)dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)d5 = mx.sym.Convolution(dact4, name='d5', kernel=(4,4), num_filter=1, no_bias=no_bias)d5 = mx.sym.Flatten(d5)# 用邏輯回歸計算最后的lossdloss = mx.sym.LogisticRegressionOutput(data=d5, label=label, name='dloss')# 返回這G與D這兩個網絡return gout, dloss在訓練的過程中,所有的原樣本的label為1,生成網絡G生成的樣本的label為0,用這樣來區別原樣本與生成的對抗樣本。生成網絡輸入的潛在空間樣本是100維的,訓練過程如下:
- 用生成網絡G生成對抗樣本gout
- 對抗樣本的label設為0,因為要先用這個訓練判別網絡D
- 用gout來訓練判別網絡D,得到梯度,但不更新
- 對原樣本的label設為1,再用之來訓練判別網絡D
- 得到梯度后合入gout得到的梯度,更新D的參數
- 下面的過程是為了得到生成網絡G的loss
- 設gout的label為1,因為生成網絡G的目標就是要生成label為1的樣本,所以訓練G的label為1。反之,如果訓練D,為了區別原樣本與生成樣本所以label為0。
- 用判別網絡D來得輸入的梯度dgout,這個梯度就是生成網絡G的loss。
- 用這個loss反向傳播生成網絡G,并更新參數。
這里面的關鍵就是用判別網絡D來得到生成網絡G的loss,之所以可以這樣,是因為這兩個網絡是可逆的。訓練的代碼如下:
if __name__ == '__main__':logging.basicConfig(level=logging.DEBUG)# =============setting============dataset = 'mnist'imgnet_path = './train.rec'ndf = 64ngf = 64nc = 3batch_size = 64Z = 100lr = 0.0002beta1 = 0.5ctx = mx.gpu(0)check_point = FalsesymG, symD = make_dcgan_sym(ngf, ndf, nc)#mx.viz.plot_network(symG, shape={'rand': (batch_size, 100, 1, 1)}).view()#mx.viz.plot_network(symD, shape={'data': (batch_size, nc, 64, 64)}).view()# ==============data==============if dataset == 'mnist':X_train, X_test = get_mnist()train_iter = mx.io.NDArrayIter(X_train, batch_size=batch_size)elif dataset == 'imagenet':train_iter = ImagenetIter(imgnet_path, batch_size, (3, 64, 64))rand_iter = RandIter(batch_size, Z)label = mx.nd.zeros((batch_size,), ctx=ctx)# =============module G=============modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctx)modG.bind(data_shapes=rand_iter.provide_data)modG.init_params(initializer=mx.init.Normal(0.02))modG.init_optimizer(optimizer='adam',optimizer_params={'learning_rate': lr,'wd': 0.,'beta1': beta1,})mods = [modG]# =============module D=============modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=('label',), context=ctx)modD.bind(data_shapes=train_iter.provide_data,label_shapes=[('label', (batch_size,))],inputs_need_grad=True)modD.init_params(initializer=mx.init.Normal(0.02))modD.init_optimizer(optimizer='adam',optimizer_params={'learning_rate': lr,'wd': 0.,'beta1': beta1,})mods.append(modD)# ============printing==============def norm_stat(d):return mx.nd.norm(d)/np.sqrt(d.size)mon = mx.mon.Monitor(10, norm_stat, pattern=".*output|d1_backward_data", sort=True)mon = Noneif mon is not None:for mod in mods:passdef facc(label, pred):pred = pred.ravel()label = label.ravel()return ((pred > 0.5) == label).mean()def fentropy(label, pred):pred = pred.ravel()label = label.ravel()return -(label*np.log(pred+1e-12) + (1.-label)*np.log(1.-pred+1e-12)).mean()mG = mx.metric.CustomMetric(fentropy)mD = mx.metric.CustomMetric(fentropy)mACC = mx.metric.CustomMetric(facc)print('Training...')stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')# =============train===============for epoch in range(100):train_iter.reset()for t, batch in enumerate(train_iter):rbatch = rand_iter.next()if mon is not None:mon.tic()# 首先生成對抗樣本modG.forward(rbatch, is_train=True)outG = modG.get_outputs()# update discriminator on fake# 這里的負樣本label為0,正樣本label為1,不像普遍的mnist一樣。那么modG就想生成樣本label為1的,modD要將modG生成的數據判定為0# train_iter(真實樣本)中的數據判定為1。label[:] = 0modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)modD.backward()#modD.update()# 先Copy得到的對抗樣本的梯度,要注意是復制不是引用。gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]modD.update_metric(mD, [label])modD.update_metric(mACC, [label])# update discriminator on real# 對真實樣本的數據訓練label[:] = 1batch.label = [label]modD.forward(batch, is_train=True)modD.backward()# 對抗樣本與真實樣本的梯度合到一起建行梯度更新for gradsr, gradsf in zip(modD._exec_group.grad_arrays, gradD):for gradr, gradf in zip(gradsr, gradsf):gradr += gradfmodD.update()modD.update_metric(mD, [label])modD.update_metric(mACC, [label])# update generator# 更新modG的參數,這里要注意的是,modG想要生成的樣本label是1的,所以在modD中用了這個label,就是想生成的樣本向label=1靠近。# 前向和向后生成輸入數據的梯度diffDlabel[:] = 1modD.forward(mx.io.DataBatch(outG, [label]), is_train=True)modD.backward()diffD = modD.get_input_grads()# diffD就是modG的loss產生的梯度,用它來向后傳播并更新參數。modG.backward(diffD)modG.update()mG.update([label], modD.get_outputs())if mon is not None:mon.toc_print()t += 1if t % 10 == 0:print('epoch:', epoch, 'iter:', t, 'metric:', mACC.get(), mG.get(), mD.get())mACC.reset()mG.reset()mD.reset()visual('gout', outG[0].asnumpy())diff = diffD[0].asnumpy()diff = (diff - diff.mean())/diff.std()visual('diff', diff)visual('data', batch.data[0].asnumpy())if check_point:print('Saving...')modG.save_params('%s_G_%s-%04d.params'%(dataset, stamp, epoch))modD.save_params('%s_D_%s-%04d.params'%(dataset, stamp, epoch))訓練的結果部分結果如下,gout是生成的樣本,data是原樣本,diff是它們的差。可以從后面生成的gout中看到,結果缺少一些數字,比如2、3等,這是因為我們沒有對各個數字的潛在空間進行生成樣本而是用統一的空間,這個統一的空間中對應的數字可能沒有2、3等或者說它們點的比例相對來說比較小,樣例用到的空間只是保證生成樣本是數字,但并不保證每個數字都會有,如果我保證生成每個數字的樣本,那么得重新設計程序,但原理和例程相差不大。
圖4 輸出的圖像結果:data是原始數據,gout是G生成的對搞樣本,diff是兩者的差。過程打印的輸出如下:
epoch: 99 iter: 930 metric: ('facc', 1.0) ('fentropy', 8.3449375152587884) ('fentropy', 0.00077932097192388026)【防止爬蟲轉載而導致的格式問題——鏈接】:
http://www.cnblogs.com/heguanyou/p/7642608.html
轉載于:https://www.cnblogs.com/heguanyou/p/7642608.html
總結
以上是生活随笔為你收集整理的用MXNet实现mnist的生成对抗网络(GAN)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ES6精华: 解构运算符 扩展运算符
- 下一篇: 用VB.NET(Visual Basic