tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)
生活随笔
收集整理的這篇文章主要介紹了
tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
tensorflow 在實現 Batch Normalization(各個網絡層輸出的歸一化)時,主要用到以下兩個 api:
- tf.nn.moments(x, axes, name=None, keep_dims=False) ? mean, variance:
- 統計矩,mean 是一階矩,variance 則是二階中心矩
- tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
- tf.nn.batch_norm_with_global_normalization(t, m, v, beta, gamma, variance_epsilon, scale_after_normalization, name=None)
- 由函數接口可知,tf.nn.moments 計算返回的 mean 和 variance 作為 tf.nn.batch_normalization 參數進一步調用;
1. tf.nn.moments,矩
tf.nn.moments 返回的 mean 表示一階矩,variance 則是二階中心矩;
如我們需計算的 tensor 的 shape 為一個四元組 [batch_size, height, width, kernels],一個示例程序如下:
import tensorflow as tf shape = [128, 32, 32, 64] a = tf.Variable(tf.random_normal(shape)) # a:activations axis = list(range(len(shape)-1)) # len(x.get_shape()) a_mean, a_var = tf.nn.moments(a, axis)這里我們僅給出 a_mean, a_var 的維度信息,
sess = tf.Session() sess.run(tf.global_variables_initalizer())sess.run(a_mean).shape # (64, ) sess.run(a_var).shape # (64, ) ? 也即是以 kernels 為單位,batch 中的全部樣本的均值與方差references
- 談談Tensorflow的Batch Normalization
轉載于:https://www.cnblogs.com/mtcnn/p/9421623.html
總結
以上是生活随笔為你收集整理的tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: C#泛型方法解析
- 下一篇: Android 之UID and PID