谷歌开源的 GAN 库--TFGAN
本文大約 8000 字,閱讀大約需要 12 分鐘
第一次翻譯,限于英語水平,可能不少地方翻譯不準確,請見諒!
最近谷歌開源了一個基于 TensorFlow 的庫–TFGAN,方便開發者快速上手 GAN 的訓練,其 Github 地址如下:
https://github.com/tensorflow/models/tree/master/research/gan
原文網址:Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)
如果你玩過波斯王子,那你應該知道你需要保護自己不被”影子“所殺掉,但這也是一個矛盾:如果你殺死“影子”,那游戲就結束了;但你不做任何事情,那么游戲也會輸掉。
盡管生成對抗網絡(GAN)有不少優點,但它也面臨著相似的區分問題。大部分支持 GAN 的深度學習專業也是非常謹慎的支持它,并指出它確實存在穩定性的問題。
GAN 的這個問題也可以稱做整體收斂性問題。盡管判別器 D 和 生成器 D 相互競爭博弈,但同時也相互依賴對方來達到有效的訓練。如果其中一方訓練得很差,那整個系統也會很差(這也是之前提到的梯度消失或者模式奔潰問題)。并且你也需要確保他們不會訓練太過度,造成另一方無法訓練了。因此,波斯王子是一個很有趣的概念。
首先,神經網絡的提出就是為了模仿人類的大腦(盡管是人為的)。它們也已經在物體識別和自然語言處理方面取得成功。但是,想要在思考和行為上與人類一致,這還有非常大的差距。
那么是什么讓 GANs 成為機器學習領域一個熱門話題呢?因為它不僅只是一個相對新的結構,它更加是一個比之前其他模型都能更加準確的對真實數據建模,可以說是深度學習的一個革命性的變化。
最后,它是一個同時訓練兩個獨立的網絡的新模型,這兩個網絡分別是判別器和生成器。這樣一個非監督神經網絡卻能比其他傳統網絡得到更好性能的結果。
但目前事實是我們對 GANs 的研究還只是非常淺層,仍然有著很多挑戰需要解決。GANs 目前也存在不少問題,比如無法區分在某個位置應該有多少特定的物體,不能應用到 3D 物體,以及也不能理解真實世界的整體結構。當然現在有大量研究正在研究如何解決上述問題,新的模型也取得更好的性能。
而最近谷歌為了讓 GANs 更容易實現,設計開發并開源了一個基于 TensorFlow 的輕量級庫–TFGAN。
根據谷歌的介紹,TFGAN 提供了一個基礎結構來減少訓練一個 GAN 模型的難度,同時提供非常好測試的損失函數和評估標準,以及給出容易上手的例子,這些例子強調了 TFGAN 的靈活性和易于表現的優點。
此外,還提供了一個教程,包含一個高級的 API 可以快速使用自己的數據集訓練一個模型。
上圖是展示了對抗損失在圖像壓縮方面的效果。最上方第一行圖片是來自 ImageNet 數據集的圖片,也是原始輸入圖片,中間第二行展示了采用傳統損失函數訓練得到的圖像壓縮神經網絡的壓縮和解壓縮效果,最底下一行則是結合傳統損失函數和對抗損失函數訓練的網絡的結果,可以看到盡管基于對抗損失的圖片并不像原始圖片,但是它比第二行的網絡得到更加清晰和細節更好的圖片。
TFGAN 既提供了幾行代碼就可以實現的簡答函數來調用大部分 GAN 的使用例子,也是建立在包含復雜 GAN 設計的模式化方式。這就是說,我們可以采用自己需要的模塊,比如損失函數、評估策略、特征以及訓練等等,這些都是獨立的模塊。TFGAN 這樣的設計方式其實就滿足了不同使用者的需求,對于入門新手可以快速訓練一個模型來看看效果,對于需要修改其中任何一個模塊的使用者也能修改對應模塊,而不會牽一發而動全身。
最重要的是,谷歌也保證了這個代碼是經過測試的,不需要擔心一般的 GAN 庫造成的數字或者統計失誤。
開始使用
首先添加以下代碼來導入 tensorflow 和 聲明一個 TFGAN 的實例:
import tensorflow as tf tfgan = tf.contrib.gan為何使用 TFGAN
- 采用良好測試并且很靈活的調用接口實現快速訓練生成器和判別器網絡,此外,還可以混合 TFGAN、原生 TensorFlow以及其他自定義框架代碼;
- 使用實現好的GAN 的損失函數和懲罰策略 (比如 Wasserstein loss、梯度懲罰等)
- 訓練階段對 GAN 進行監控和可視化操作,以及評估生成結果
- 使用實現好的技巧來穩定和提高性能
- 基于常規的 GAN 訓練例子來開發
- 采用GANEstimator接口里快速訓練一個 GAN 模型
- TFGAN 的結構改進也會自動提升你的 TFGAN 項目的性能
- TFGAN 會不斷添加最新研究的算法成果
TFGAN 的部件有哪些呢?
TFGAN 是由多個設計為獨立的部件組成的,分別是:
- core:提供了一個主要的訓練 GAN 模型的結構。訓練過程分為四個階段,每個階段都可以采用自定義代碼或者 調用 TFGAN 庫接口來完成;
- features:包含許多常見的 GAN 運算和正則化技術,比如實例正則化(instance normalization)
- losses:包含常見的 GAN 的損失函數和懲罰機制,比如 Wasserstein loss、梯度懲罰、相互信息懲罰等
- evaulation:使用一個預訓練好的 Inception 網絡來利用Inception Score或者Frechet Distance評估標準來評估非條件生成模型。當然也支持利用自己訓練的分類器或者其他方法對有條件生成模型的評估
- examples and tutorial:使用 TFGAN 訓練 GAN 模型的例子和教程。包含了使用非條件和條件式的 GANs 模型,比如 InfoGANs 等。
訓練一個 GAN 模型
典型的 GAN 模型訓練步驟如下:
當然,GAN 的設置有多種形式。比如,你可以在非條件下訓練生成器生成圖片,或者可以給定一些條件,比如類別標簽等輸入到生成器中來訓練。無論是哪種設置,TFGAN 都有相應的實現。下面將結合代碼例子來進一步介紹。
實例
非條件 MNIST 圖片生成
第一個例子是訓練一個生成器來生成手寫數字圖片,即 MNIST 數據集。生成器的輸入是從多變量均勻分布采樣得到的隨機噪聲,目標輸出是 MNIST 的數字圖片。具體查看論文“Generative Adversarial Networks”。代碼如下:
# 配置輸入 # 真實數據來自 MNIST 數據集 images = mnist_data_provider.provide_data(FLAGS.batch_size) # 生成器的輸入,從多變量均勻分布采樣得到的隨機噪聲 noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 調用 tfgan.gan_model() 函數定義生成器和判別器網絡模型 gan_model = tfgan.gan_model(generator_fn=mnist.unconditional_generator, discriminator_fn=mnist.unconditional_discriminator, real_data=images,generator_inputs=noise)# 調用 tfgan.gan_loss() 定義損失函數 gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss)# 調用 tfgan.gan_train_ops() 指定生成器和判別器的優化器 train_ops = tfgan.gan_train_ops(gan_model,gan_loss,generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))# tfgan.gan_train() 開始訓練,并指定訓練迭代次數 num_steps tfgan.gan_train(train_ops,hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],logdir=FLAGS.train_log_dir)條件式 MNIST 圖片生成
第二個例子同樣還是生成 MNIST 圖片,但是這次輸入到生成器的不僅僅是隨機噪聲,還會給類別標簽,這種 GAN 模型也被稱作條件 GAN,其目的也是為了讓 GAN 訓練不會太過自由。具體可以看論文“Conditional Generative Adversarial Nets”。
代碼方面,僅僅需要修改輸入和建立生成器與判別器模型部分,如下所示:
# 配置輸入 # 真實數據來自 MNIST 數據集,這里增加了類別標簽--one_hot_labels images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size) # 生成器的輸入,從多變量均勻分布采樣得到的隨機噪聲 noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 調用 tfgan.gan_model() 函數定義生成器和判別器網絡模型 gan_model = tfgan.gan_model(generator_fn=mnist.conditional_generator, discriminator_fn=mnist.conditional_discriminator, real_data=images,generator_inputs=(noise, one_hot_labels)) # 生成器的輸入增加了類別標簽# 剩余的代碼保持一致 ...對抗損失
第三個例子結合了 L1 pixel loss 和對抗損失來學習自動編碼圖片。瓶頸層可以用來傳輸圖片的壓縮表示。如果僅僅使用 pixel-wise loss,網絡只回傾向于生成模糊的圖片,但 GAN 可以用來讓這個圖片重建過程更加逼真。具體可以看論文“Full Resolution Image Compression with Recurrent Neural Networks”來了解如何用 GAN 來實現圖像壓縮,以及論文“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”了解如何用 GANs 來增強生成的圖片質量。
代碼如下:
# 配置輸入 images = image_provider.provide_data(FLAGS.batch_size)# 配置生成器和判別器網絡 gan_model = tfgan.gan_model(generator_fn=nets.autoencoder, # 自定義的 autoencoderdiscriminator_fn=nets.discriminator, # 自定義的 discriminatorreal_data=images,generator_inputs=images)# 建立 GAN loss 和 pixel loss gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 結合兩個 loss gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代碼保持一致 ...圖像轉換
第四個例子是圖像轉換,它是將一個領域的圖片轉變成另一個領域的同樣大小的圖片。比如將語義分割圖變成街景圖,或者是灰度圖變成彩色圖。具體細節看論文“Image-to-Image Translation with Conditional Adversarial Networks”。
代碼如下:
# 配置輸入,注意增加了 target_image input_image, target_image = data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判別器網絡 gan_model = tfgan.gan_model(generator_fn=nets.generator, discriminator_fn=nets.discriminator, real_data=target_image,generator_inputs=input_image)# 建立 GAN loss 和 pixel loss gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.least_squares_generator_loss,discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)# 結合兩個 loss gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)# 剩下代碼保持一致 ...InfoGAN
最后一個例子是采用 InfoGAN 模型來生成 MNIST 圖片,但是可以不需要任何標簽來控制生成的數字類型。具體細節可以看論文“InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets”。
代碼如下:
# 配置輸入 images = mnist_data_provider.provide_data(FLAGS.batch_size)# 配置生成器和判別器網絡 gan_model = tfgan.infogan_model(generator_fn=mnist.infogan_generator, discriminator_fn=mnist.infogran_discriminator, real_data=images,unstructured_generator_inputs=unstructured_inputs, # 自定義輸入structured_generator_inputs=structured_inputs) # 自定義# 配置 GAN loss 以及相互信息懲罰 gan_loss = tfgan.gan_loss(gan_model,generator_loss_fn=tfgan_losses.wasserstein_generator_loss,discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,gradient_penalty=1.0,mutual_information_penalty_weight=1.0)# 剩下代碼保持一致 ...自定義模型的創建
最后同樣是非條件 GAN 生成 MNIST 圖片,但利用GANModel函數來配置更多參數從而更加精確控制模型的創建。
代碼如下:
# 配置輸入 images = mnist_data_provider.provide_data(FLAGS.batch_size) noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])# 手動定義生成器和判別器模型 with tf.variable_scope('Generator') as gen_scope:generated_images = generator_fn(noise) with tf.variable_scope('Discriminator') as dis_scope:discriminator_gen_outputs = discriminator_fn(generated_images) with variable_scope.variable_scope(dis_scope, reuse=True):discriminator_real_outputs = discriminator_fn(images) generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope)# 依賴于你需要使用的 TFGAN 特征,你并不需要指定 `GANModel`函數的每個參數,不過 # 最少也需要指定判別器的輸出和變量 gan_model = tfgan.GANModel(generator_inputs,generated_data,generator_variables,gen_scope,generator_fn,real_data,discriminator_real_outputs,discriminator_gen_outputs,discriminator_variables,dis_scope,discriminator_fn)# 剩下代碼和第一個例子一樣 ...最后,再次給出 TFGAN 的 Github 地址如下:
https://github.com/tensorflow/models/tree/master/research/gan
如果有翻譯不當的地方或者有任何建議和看法,歡迎留言交流;也歡迎關注我的微信公眾號–機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!
總結
以上是生活随笔為你收集整理的谷歌开源的 GAN 库--TFGAN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python读取PDF文档并翻译
- 下一篇: Dataway让 Spring Boot