神经网络中BP(back propagation)到底在干些什么
前言
想要理解神經網絡的工作原理,反向傳播(BP)是必須搞懂的東西。BP其實并不難理解,說白了就是用鏈式法則(chain rule)算算算。本文試圖以某個神經網絡為例,盡可能直觀,詳細,明了地說明反向傳播的整個過程。
正向傳播
在反向傳播之前,必然是要有正向傳播的。正向傳播時的所有參數都是預先隨機取的,沒人能說這樣的參數好不好,得要試過才知道,試過之后,根據得到的結果與目標值的差距,再通過反向傳播取修正各個參數。下圖就是一個神經網絡,我們以整個為例子來說明整個過程
圖1:神經網絡圖我懶,此圖取自參考文獻[1],圖中的各個符號說明如下(順序從下往上):
xix_ixi?:輸入樣本中的第iii個特征的值
vihv_{ih}vih?:xix_ixi?與隱層第hhh個神經元連接的權重
αh\alpha_hαh?:第h個隱層神經元的輸入,αh=∑i=1dvihxi\alpha_h=\sum_{i=1}^d v_{ih}x_iαh?=∑i=1d?vih?xi?
bhb_hbh?:第h個隱層神經元的輸出,某個神經元的輸入和輸出有關系f(αh)=bhf(\alpha_h)=b_hf(αh?)=bh?,其中f(x)f(x)f(x)為激活函數,比如Sigmoid函數f(x)=11+e?xf(x)=\dfrac{1}{1+e^{-x}}f(x)=1+e?x1?
whjw_hjwh?j:隱層第hhh個神經元和輸出層第jjj個神經元連接的權重
βj\beta_jβj?:輸出層第jjj個神經元的輸入,βj=∑h=1qwhjbh\beta_j=\sum_{h=1}^q w_{hj}b_hβj?=∑h=1q?whj?bh?
yjy_jyj?:第jjj個輸出層神經元的輸出,f(βj)=yjf(\beta_j)=y_jf(βj?)=yj?,f(x)f(x)f(x)為激活函數
為了方便書寫,我們假設截距項bias已經在參數www和vvv之中了,也就是說在輸入數據的時候,我們增添了一個x0=1x_0=1x0?=1,由于我懶,圖中沒有畫出來,但心里要清楚這一點。
相信看了圖之后,神經網絡的正向傳播就相當簡單明了了,不過,這里我還是啰嗦一句,舉個例子,比如輸出yjy_jyj?的計算方法為
yj=f(βj)=f(∑h=1qwhjbh)=f(∑h=1qwhjf(αh))=f(∑h=1qwhjf(∑i=1dvihxi))y_j=f(\beta_j)=f(\sum_{h=1}^q w_{hj}b_h)=f(\sum_{h=1}^q w_{hj}f(\alpha_h))=f(\sum_{h=1}^q w_{hj}f(\sum_{i=1}^d v_{ih}x_i))yj?=f(βj?)=f(h=1∑q?whj?bh?)=f(h=1∑q?whj?f(αh?))=f(h=1∑q?whj?f(i=1∑d?vih?xi?))
反向傳播
好了,通過正向傳播,我們就已經得到了lll個yyy的值了,將它們與目標值ttt,也就是我們期望它們成為的值作比較,并放入損失函數中,記作LLL。
損失LLL可以自行選擇,比如常見的均方誤差L=12∑j=1l(yj?tj)2L=\dfrac{1}{2}\sum_{j=1}^l (y_j - t_j)^2L=21?∑j=1l?(yj??tj?)2
利用這個誤差,我們將進行反向傳播,以此來更新參數www和vvv。更新時,我們采用的是梯度下降法,也就是
{w:=w+Δwv:=v+Δv\begin{cases}w := w + \Delta w \\ v := v + \Delta v\end{cases}{w:=w+Δwv:=v+Δv?
其中,Δw=?η?L?w\Delta w = -\eta \dfrac{\partial L}{\partial w}Δw=?η?w?L?,Δv=?η?L?v\Delta v = -\eta \dfrac{\partial L}{\partial v}Δv=?η?v?L?,η\etaη為學習率。
下面要做的工作就是計算出每個參數的梯度,這也就是鏈式法則發揮作用的地方了。
比如,我們要計算whjw_{hj}whj?。從網絡結構中不難看出whjw_{hj}whj?影響了βj\beta_jβj?從而影響了yjy_jyj?,最終影響了LLL所以我們有
Δwhj=?η?βj?whj?yj?βj?L?yj\Delta w_{hj}=-\eta \dfrac{\partial \beta_j}{\partial w_{hj}} \dfrac{\partial y_j}{\partial \beta_j} \dfrac{\partial L}{\partial y_j}Δwhj?=?η?whj??βj???βj??yj???yj??L?
只要確定了損失函數LLL和激活函數f(x)f(x)f(x),上面所有的都是可以算的,而且?βh?whj=bh\dfrac{\partial \beta_h}{\partial w_{hj}} = b_h?whj??βh??=bh?這點是顯而易見的。并且,?yj?βj=?f(βj)?βj\dfrac{\partial y_j}{\partial \beta_j} = \dfrac{\partial f(\beta_j)}{\partial \beta_j}?βj??yj??=?βj??f(βj?)?就是激活函數的導數。
同理,vihv_{ih}vih?影響了αh\alpha_hαh?,從而影響了bhb_hbh?,從而影響了β1\beta_{1}β1?,β2\beta_{2}β2?,…,βl\beta_{l}βl?,從而影響了y1y_1y1?,y2y_2y2?,…,yly_lyl?,最終影響了LLL。
Δvih=?η?αh?vih?bh?αh∑j=1l(?βj?bh?yj?βj?L?yj)\Delta v_{ih} = -\eta \dfrac{\partial \alpha_h}{\partial v_{ih}} \dfrac{\partial b_h}{\partial \alpha_h}\sum_{j=1}^l (\dfrac{\partial \beta_j}{\partial b_h} \dfrac{\partial y_j}{\partial \beta_j} \dfrac{\partial L}{\partial y_j})Δvih?=?η?vih??αh???αh??bh??j=1∑l?(?bh??βj???βj??yj???yj??L?)
其中,?αh?vih=xi\dfrac{\partial \alpha_h}{\partial v_{ih}}=x_i?vih??αh??=xi?,?βj?bh=whj\dfrac{\partial \beta_j}{\partial b_h} = w_{hj}?bh??βj??=whj?,?yj?βj=?f(βj)?βj\dfrac{\partial y_j}{\partial \beta_j} = \dfrac{\partial f(\beta_j)}{\partial \beta_j}?βj??yj??=?βj??f(βj?)?和?bh?αh=?f(αh)?αh\dfrac{\partial b_h}{\partial \alpha_h} = \dfrac{\partial f(\alpha_h)}{\partial \alpha_h}?αh??bh??=?αh??f(αh?)?是激活函數的導數。
至此,我們已經可以算出Δw\Delta wΔw和Δv\Delta vΔv,從而更新參數了。
關于激活函數的幾點說明
從推出的公式中不難看出,隨著反向傳播向輸出層這個方向的推進,激活函數的影響也就越來越來了。通俗一點來說,在計算Δwhj\Delta w_{hj}Δwhj?,我們只乘了一個激活函數的導數,然而在計算Δvih\Delta v_{ih}Δvih?時,我們乘了多個激活函數的導數。
Δwhj=?η?βj?whjf′(βj)?L?yj\Delta w_{hj}=-\eta \dfrac{\partial \beta_j}{\partial w_{hj}} f'(\beta_j) \dfrac{\partial L}{\partial y_j}Δwhj?=?η?whj??βj??f′(βj?)?yj??L?
Δvih=?η?αh?vihf′(αh)∑j=1l(?βj?bhf′(βj)?L?yj)\Delta v_{ih} = -\eta \dfrac{\partial \alpha_h}{\partial v_{ih}} f'(\alpha_h) \sum_{j=1}^l (\dfrac{\partial \beta_j}{\partial b_h} f'(\beta_j) \dfrac{\partial L}{\partial y_j})Δvih?=?η?vih??αh??f′(αh?)j=1∑l?(?bh??βj??f′(βj?)?yj??L?)
不難推斷出,如果隱層的層數更多的話,激活函數的影響還要更大。
一個比較傳統的激活函數時Sigmoid函數,其圖像如下所示。
不難發現,當xxx比較大的時候,或比較小的時候,f′(x)f'(x)f′(x)是趨近于0的,當神經網絡的層數很深的時候,這么多個接近0的數相乘就會導致傳到輸出層這邊的時候已經沒剩下多少信息了,這時梯度對模型的更新就沒有什么貢獻了。那么大多數神經元將會飽和,導致網絡就幾乎不學習。這其實也是Sigmoid函數現在在神經網絡中不再受到青睞的原因之一。
另一個原因是Sigmoid 函數不是關于原點中心對稱的,這會導致梯度在反向傳播過程中,要么全是正數,要么全是負數。導致梯度下降權重更新時出現 Z 字型的下降。
所以,就出現了ReLU這個激活函數 f(x)=max?(0,x)f\left( x\right) =\max \left( 0,x\right)f(x)=max(0,x),其圖像如下圖所示。
ReLU 對于 SGD 的收斂有巨大的加速作用,而且只需要一個閾值就可以得到激活值,而不用去算一大堆復雜的(指數)運算。
不過,由于它左半邊的狀態,ReLU在訓練時比較脆弱并且可能“死掉”。
因此,人們又研究出了Leaky ReLU,PReLU等等的激活函數。這里不展開討論。
參考文獻
[1] 周志華. 機器學習 : = Machine learning[M]. 清華大學出版社, 2016.
[2] http://cs231n.github.io/neural-networks-1/
[2] http://www.jianshu.com/p/6df4ab7c235c
總結
以上是生活随笔為你收集整理的神经网络中BP(back propagation)到底在干些什么的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 天池 在线编程 最长AB子串(哈希)
- 下一篇: LeetCode 1769. 移动所有球