深度学习(Deep Learning):循环神经网络一(RNN)
原址:https://blog.csdn.net/fangqingan_java/article/details/53014085
概述
循環(huán)神經網絡(RNN-Recurrent Neural Network)是神經網絡家族中的一員,擅長于解決序列化相關問題。包括不限于序列化標注問題、NER、POS、語音識別等。RNN內容比較多,分成三個小節(jié)進行介紹,內容包括RNN基礎以及求解算法、LSTM以及變種GRU、RNN相關應用。本節(jié)主要介紹
1.RNN基礎知識介紹?
2.RNN模型優(yōu)化以及存在的問題?
3.RNN模型變種
RNN知識點
RNN提出動機
RNN的提出可以有效解決以下問題:
編碼:可以將可變輸入編碼成固定長度的向量。和CNN相比,能夠保留全局最優(yōu)特征。
計算圖展開
RNN常用以下公式獲取歷史狀態(tài)
ht=f(ht?1,xt;θ)ht=f(ht?1,xt;θ)
其中h為隱藏層,用于保存上下文信息,f是激活函數(shù)。?
用圖模型可以表達為:?
?
RNN潛在可能的展開方式如下:?
1)通過隱藏層傳遞信息?
1.該展開形式非常常用,主要包括三層輸入-隱藏層、隱藏層-隱藏層、隱藏層到輸入層。依賴信息通過隱藏層進行傳遞。?
2.參數(shù)U、V、W為共享參數(shù)
2)輸出節(jié)點連接到下一時序序列?
應用比較局限,上一時序的輸出作為下一時間點的輸入,理論上上一時間點的輸出比較固定,能夠攜帶的信息比較少。
3)只有一個輸出節(jié)點?
只在最后時間點t產生輸出,往往能夠將變成的輸入轉換為固定長度的向量表示。
RNN使用形式
在使用RNN時,主要形式有4中,如下圖所示。?
1.一對一形式(左一:Many to Many)每一個輸入都有對應的輸出。?
2.多對一形式(左二:Many to one)整個序列只有一個輸出,例如文本分類、情感分析等。?
3. 一對多形式(左三:One to Many)一個輸入產出一個時序序列,常用于seq2seq的解碼階段?
4.多對多形式(左四:Many to Many)不是每一個輸入對應一個輸出,對應到變成的輸出。
RNN數(shù)學表達以及優(yōu)化
RNN前向傳播
對于離散時間的RNN問題可以描述為,輸入序列
(x1,y1),(x2,y2),(x3,y3)......(xT,yT)(x1,y1),(x2,y2),(x3,y3)......(xT,yT)
其中時間參數(shù)t表示離散序列,不一定是真實時間點。?
對于多分類問題,目標是最小化釋然函數(shù)?
min∑t=1TL(y^(xt),yt)=min?∑tlog?p(yt|x1,x2...xt)min∑t=1TL(y^(xt),yt)=min?∑tlog?p(yt|x1,x2...xt)
?
根據上面經典的RNN網絡結構,前向傳播過程如下:?
如上圖U、V、W分別表示輸入到隱藏層、隱藏層到輸出以及隱藏到隱藏層的連接參數(shù)。?
1. 隱藏層節(jié)點權值:at=b+Wht?1+Uxtat=b+Wht?1+Uxt?
2. 隱藏層非線性變換:?ht=tanh(at)ht=tanh(at)?
3. 輸出層:?ot=c+Vhtot=c+Vht?
4. softmax層:?y^t=softmax(ot)y^t=softmax(ot)
RNN優(yōu)化算法-BPTT
BPTT 是求解RNN問題的一種優(yōu)化算法,也是基于BP算法改進得到和BP算法比較類似。為直觀上理解通過多分類問題進行簡單推導。?
1. 優(yōu)化目標,對于多分類問題,BPTT優(yōu)化目標轉換最小化交叉熵:
min∑tLtLt=?∑kytklogy^tkmin∑tLtLt=?∑kyktlogy^kt
這里假設有k個類?
2. 由于總的損失L為各個時序點的損失和,因此有
?L?Lt=1?L?Lt=1
3. 對于輸出層中的第i節(jié)點有
(?otL)i=?L?oti=?L?Lt?Lt?oti=y^ti?1i,yt(?otL)i=?L?oit=?L?Lt?Lt?oit=y^it?1i,yt
最后一步是交叉熵推導結果,步驟省略,了解softmax的都清楚。1i,yt1i,yt表示如果y^t==i則為1,否則為0?
4. 隱藏層節(jié)點梯度的計算,分為兩部分,第一部分 t=T。
(?hTL)i=∑j(?oTL)j?oTj?hTi=∑j(?oTL)jVij(?hTL)i=∑j(?oTL)j?ojT?hiT=∑j(?oTL)jVij
通過向量的方式表達為
(?hTL)=(?oTL)?oT?hT=(?oTL)V(?hTL)=(?oTL)?oT?hT=(?oTL)V
5.第二部分, 中間節(jié)點?t<Tt<T,對于中間節(jié)點需要考慮t+1以及以后時間點傳播的誤差,因此計算過程如下。
(?htL)i=∑j(?ht+1L)j?ht+1j?hti+∑k(?otL)k?otk?hti=隱藏層誤差反饋+輸出層誤差反饋=∑j(?ht+1L)j?ht+1j?at+1j?at+1j?hti+∑k(?otL)kVki=∑j(?ht+1L)j(1?ht+1j2)Wji+∑k(?otL)kVki=(?ht+1L)diag((1?ht+12))Wi+(?otL)Vi(?htL)i=∑j(?ht+1L)j?hjt+1?hit+∑k(?otL)k?okt?hit=隱藏層誤差反饋+輸出層誤差反饋=∑j(?ht+1L)j?hjt+1?ajt+1?ajt+1?hit+∑k(?otL)kVki=∑j(?ht+1L)j(1?hjt+12)Wji+∑k(?otL)kVki=(?ht+1L)diag((1?ht+12))Wi+(?otL)Vi
通過向量表示如下:
(?htL)=(?ht+1L)?ht+1?ht+(?otL)?ot?ht=(?ht+1L)diag((1?ht+12))W+(?otL)V(?htL)=(?ht+1L)?ht+1?ht+(?otL)?ot?ht=(?ht+1L)diag((1?ht+12))W+(?otL)V
其中diag((1?ht+12))diag((1?ht+12))是由1?ht+1i1?hit+1的平方組成的對角矩陣。?
6.根據中間結果的梯度可以推導出其他參數(shù)的梯度,結果如下
?cL?bL?VL?WL?UL=∑t(?toL)?ot?c=∑t(?toL)=∑t(?thL)?ht?b=∑t(?thL)diag((1?ht2))=∑t(?toL)?ot?V=∑t(?toL)htT=∑t(?thL)?ht?W=∑t(?thL)diag((1?ht2))ht?1T=∑t(?thL)?ht?U=∑t(?thL)diag((1?ht2))xtT?cL=∑t(?otL)?ot?c=∑t(?otL)?bL=∑t(?htL)?ht?b=∑t(?htL)diag((1?ht2))?VL=∑t(?otL)?ot?V=∑t(?otL)htT?WL=∑t(?htL)?ht?W=∑t(?htL)diag((1?ht2))ht?1T?UL=∑t(?htL)?ht?U=∑t(?htL)diag((1?ht2))xtT
7. 到此完成了對所有參數(shù)梯度的推導。
?
梯度彌散和爆炸問題
RNN訓練比較困難,主要原因在于隱藏層參數(shù)W,無論在前向傳播過程還是在反向傳播過程中都會乘上多次。這樣就會導致1)前向傳播某個小于1的值乘上多次,對輸出影響變小。2)反向傳播時會導致梯度彌散問題,參數(shù)優(yōu)化變得比較困難。?
可以通過梯度公式也可以看出梯度彌散或者爆炸問題。?
考慮到通用性,激活函數(shù)采用f(x)代替,則對隱藏層到隱藏層參數(shù)W梯度公式如下:?
?WL=∑t(?thL)?ht?W=∑t(?thL)diag(f′(ht))ht?1?WL=∑t(?htL)?ht?W=∑t(?htL)diag(f′(ht))ht?1
后面部分可以直接得到,下面詳細分析它的系數(shù)(?thL)(?htL)
?
1.考慮當t=T,即為最后一個節(jié)點時,根據上面的推導有
(?hTL)=(?oTL)?oT?hT=(?oTL)V(?hTL)=(?oTL)?oT?hT=(?oTL)V
2.當t=T-1時,
(?hT?1L)=(?ThL)?ht+1?ht=(?hTL)diag(f′(hT))W(?hT?1L)=(?hTL)?ht+1?ht=(?hTL)diag(f′(hT))W
注這里只考慮隱藏層節(jié)點對W的誤差傳遞,沒有考慮輸出層。?3. 當t=T-2時,
(?hT?2L)=(?T?1hL)?hT?1?hT?2=(?hTL)diag(f′(hT))Wdiag(f′(hT?1))W=(?hTL)diag(f′(hT))diag(f′(hT?1))W2(?hT?2L)=(?hT?1L)?hT?1?hT?2=(?hTL)diag(f′(hT))Wdiag(f′(hT?1))W=(?hTL)diag(f′(hT))diag(f′(hT?1))W2
4. 當t=k時
(?hkL)=(?ThL)∏j=k+1T?hj?hj?1=(?hTL)∏j=kTdiag(f′(hj))W(?hkL)=(?hTL)∏j=k+1T?hj?hj?1=(?hTL)∏j=kTdiag(f′(hj))W
5.此時diag(f′(hj))Wdiag(f′(hj))W的結果是一個對角矩陣,如果其中某個元素大于1,則該值會指數(shù)倍放大;否則會以指數(shù)倍縮小。?
6.因此可以看出當序列比較長,即模型有長期依賴問題時,就會產生梯度相關問題。一般情況下BPTT對于序列長度在100以內,不會暴露問題。?
7.需要注意的是,如果我們的訓練樣本被人工分為子序列,且長度都較小時,不會產生梯度問題。此時比較依賴于前期預處理
?
梯度問題解決方案
梯度爆炸問題方案
該問題采用截斷的方式有效避免,并且取得較好的效果。?
梯度彌散問題解決方案
針對該問題,有大量的解決方法,效果不一致。?
1.有效初始化+ReLU激活函數(shù)能夠得到較好效果?
2.算法上的優(yōu)化,例如截斷的BPTT算法。?
3.模型上的改進,例如LSTM、GRU單元都可以有效解決長期依賴問題。?
4.在BPTT算法中加入skip connection,此時誤差可以間歇的向前傳播。?
5.加入一些Leaky Units,思路類似于skip connection
RNN模型改進
主要有兩大類思路
雙向RNN(Bi-RNN)
此時不僅可以依賴前面的上下文,還可以依賴后面的上下文。?
深度RNN(Deep-RNN)
有多種方式進行深度RNN的組合,左一比較常用。?
總結
通過該小結的總結,可以了解到?
1)RNN模型優(yōu)勢以及處理問題形式。?
2)標準RNN的數(shù)學公式以及BPTT推導?
3)RNN模型訓練中的梯度問題以及如何避免
總結
以上是生活随笔為你收集整理的深度学习(Deep Learning):循环神经网络一(RNN)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pycharm同一目录下无法import
- 下一篇: tf.slice解析