神经网络训练细节之batch normalization
在對神經網絡進行訓練的時候,經常涉及到的一個概念就是batch normalization,那么究竟什么是batch normalization,以及為什么在訓練的時候要加入batch normalization呢?以下就是我的一些學習心得。
1、提出batch normalization的原因
? ? ? (1) 我們在對某個神經網絡訓練完成之后,需要測試該網絡,那么如果我們的訓練集數據與測試集數據具有不同的分布,最終的測試結果會有比較大的影響。也就是說我們的模型是基于訓練集數據進行訓練的,整個訓練的過程肯定會受到訓練集數據分布的影響,這種分布的影響也會體現到模型的訓練中,但如果測試集數據具有與訓練集數據不一樣的樣本,那么我們的模型對于最后的測試集數據的輸出結果可能就是錯誤的。所以我們希望不論是訓練還是測試的時候,輸入數據的分布最好都是一致并且穩定的,當然一般來說,數據訓練集與測試集的數據分布都是一致的,因為如果分布差距很大,那么這就不能看作是同一個問題。
? ? ? ?(2)當網絡比較深的時候,即使我們將batch規范化成均值為0,單位方差的數據輸入,但是后面層的輸出就不能保證了,隨著網絡的深入,后面網絡的輸出將不再滿足均值為0,方差為1。這也就是說網絡在訓練的時候,每次迭代輸入該層網絡的數據的分布都不一樣,這就使得網絡在訓練的過程中有點無所適從,相應收斂的速度也會降低,針對這個問題,就考慮是否可以在每一層的輸出后面都加一個BN(batch normalization)層,從而使得每層輸入數據都是零均值,單位方差的數據,從而可以使得整個網絡的收斂更快。
? ? ? (3)還有一個原因是促使batch normalization提出的又一原因,在卷積神經網絡的訓練中,往往要對圖像進行“白化”處理,這樣可以加快網絡訓練時候的收斂速度。“白化”操作之所以會加快收斂速度,原因是“白化”處理的本質是去除數據之間的相關性,這樣就簡化了后續數據獨立分量的提取過程。
2、什么是batch normalization?
? ? ?batch normalization其實就是對數據進行規范化,將其分布變成均值為0,單位方差的正太分布。實現這一功能的方法非常簡單,公式如下:
但是如果僅僅進行這樣的操作,就會降低整個網絡的表達能力。距離來說,加入激活函數采用的是sigmoid函數,那么當輸入數據滿足均值為0,單位方差的分布時,由于sigmoid函數的曲線特性(可見博客《關于激勵函數的一些思考》),在0附近的范圍內,整個sigmoid曲線接近線性,激勵函數的非線性特性就得不到體現,這樣整個網絡的表達能力會降低,為了改善這個情況,在對數據采用規范化的過程中引入兩個參數,對數據進行縮放和平移,具體公式如下:
這兩個公式中涉及的均值跟方差都是針對所有數據的,但在實際訓練的時候,我們是對訓練數據進行洗牌,并隨機抽取一個mini-batch進行訓練,所以在實際的訓練中,我們只是在mini-batch上實現對數據的規范,公式如下:
3、測試時候的batch normalization
? ? ?那么對每層的輸入數據進行batch normalization之后,我們可以更加快速地實現收斂,那么當整個模型訓練成功之后,我們需要對網絡進行測試,這時候輸入的單個的樣本,只有一個數據,如何對這個輸入樣本進行規范呢?
? ? ?答案就是我們在訓練的時候要記住每個mini-batch的均值與方差,然后根據這些數據計算訓練集整體的均值與方差,公式如下:
利用整體的均值與方差,實現對單個樣本的規范化,然后再輸入到訓練好的網絡中進行測試。
總結
以上是生活随笔為你收集整理的神经网络训练细节之batch normalization的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ZED ROS包发布topic介绍
- 下一篇: java paint调用,求教 如何调用