深度学习04-RNN
文章目錄
- 1 為什么需要RNN
- 1.1RNN的應用場景
- 1.2 DNN和CNN不能解決的問題
- 2 RNN的網絡結構
- 2.1 RNN基礎結構
- 2.2 不同類型的RNN
- 3 RNN的優化算法BPTT
- 4 LSTM
- 5 GRU
1 為什么需要RNN
1.1RNN的應用場景
1 模仿論文(生成序列)。輸入是一堆的論文文章,輸出是符合論文格式的文本。
2 模仿linux 內核代碼寫程序(生成序列)
3 模仿小四寫文章(生成文本序列)
4 機器翻譯
5 image to text 看圖說話
1.2 DNN和CNN不能解決的問題
深度神經網絡DNN是上面這個樣子。前一層輸出是后一層輸入。每一層的輸入輸出是獨立的。第n層的輸出和第n+1層的輸出是獨立的,是沒有關系的。CNN也一樣。例如一張圖像中要畫出貓和狗的位置,那貓和狗是獨立的,是用不同的神經元捕獲特征。不會去根據貓的位置或者特征推測狗的位置。
但有些任務中后續的輸出和之前的內容是有關系的。例如完形填空:我是中國人,我的母語是_____。RNN就是用來解決這類問題。
2 RNN的網絡結構
2.1 RNN基礎結構
RNN網絡結構的特點是每一層網絡執行相同的任務,但是輸出依賴于輸入和記憶。
W,U,V是三個權重向量(是向量還是矩陣?),并且在所有網絡層,值是相同的。
xtx_txt?是t時刻的輸入
StS_tSt?是t時刻的記憶:St=f(UXt+WSt?1)S_t=f(UX_t+WS_{t-1})St?=f(UXt?+WSt?1?),f可以是tanh等函數,這個函數應該是一個值域范圍固定的函數,例如函數范圍在(-1,1)之間。這樣可以保證神經網絡不會爆炸
OtO_tOt?是t時刻的輸出,如果是輸出下個詞的話,那就是輸出每個候選詞的概率,Ot=Softmax(VSt)O_t=Softmax(VS_t)Ot?=Softmax(VSt?)
我們用高中學習的類比。如果t=高三,那么
W,U,V是我們的學習方法,高一,高二,高三這三年學習方法不變(在一輪迭代中)。
xtx_txt?是高三這一年老師教給我們的知識。
StS_tSt?是高三學習完以后能夠記住的知識。我們能記住的知識取決于高二學習后能記住的知識St?1S_{t-1}St?1?和高三這一年老師能交給我們的知識xtx_txt?。
OtO_tOt?可以是高三畢業考試的成績,它與高三學習完以后能夠記住的知識有關。當然成績是一個線性回歸問題,與上面例子中說的多分類問題是兩種類型的問題。
由于每一層共享參數W、U、V,所以RNN的參數量與CNN相比,是比較小的。
在有些問題中不一定有OtO_tOt?。例如情感分類的任務中,只需要在讀完所有句子,也就是最后一個時刻輸出情感類別即可,過程中不需要。
StS_tSt?并不能捕捉所有時刻的信息,StS_tSt?是一個矩陣,能夠存儲的信息是有限的。
示例代碼:唐詩生成器
2.2 不同類型的RNN
1 深層雙向RNN
在有些情況下,當前的輸出不僅依賴于之前序列的元素,還與之后的元素有關。例如在句法解析中。“He said, Teddy bears are on sale” and “He said, Teddy Roosevelt was a great President。在上面的兩句話中,當我們看到“Teddy”和前兩個詞“He said”的時候,我們有可能無法理解這個句子是指President還是Teddy bears。因此,為了解決這種歧義性,我們需要往后查找。
S?t=f(W?xt+V?S?t?1+b?)\vec S_t = f(\vec Wx_t+\vec V\vec S_{t-1}+\vec b)St?=f(Wxt?+VSt?1?+b)
S←t=f(W←xt+V←S←t+1+b←)\overleftarrow{S}_{t}=f\left(\overleftarrow{W} x_{t}+\overleftarrow{V} \overleftarrow{S}_{t+1}+\overleftarrow{b}\right)St?=f(Wxt?+VSt+1?+b)
yt=g(U[S?t;S←t]+c)y_{t}=g\left(U\left[\vec{S}_{t} ; \overleftarrow{S}_t\right]+c\right)yt?=g(U[St?;St?]+c)
從左向右計算記憶S?t\vec S_tSt?,從右向左計算記憶S←t\overleftarrow{S}_tSt?,U[S?t;S←t]U\left[\vec{S}_{t} ; \overleftarrow{S}_t\right]U[St?;St?]是對兩個矩陣做拼接。
2 深層雙向RNN
圖中的h和之前的S是等價的。
這樣的網絡是說在每個時刻不僅學習一遍,可以學習3遍甚至更多。類比于,你讀了三遍高一,三遍高二,三遍高三。
3 RNN的優化算法BPTT
BPTT和BP很類似,是一個思路,但是因為這里和時刻有關系。
在這樣一個多分類器中,損失函數是一個交叉熵。
某一時刻的損失函數是:Et(yt,y^t)=?ytlog?y^tE_{t}\left(y_{t}, \hat{y}_{t}\right)=-y_{t} \log \hat{y}_{t}Et?(yt?,y^?t?)=?yt?logy^?t?
最終的損失函數是所有時刻的交叉熵相加:E(y,y^)=∑tEt(yt,y^t)=?∑ytlog?y^t\begin{aligned} E(y, \hat{y}) &=\sum_{t} E_{t}\left(y_{t}, \hat{y}_{t}\right) \\ &=-\sum y_{t} \log \hat{y}_{t} \end{aligned}E(y,y^?)?=t∑?Et?(yt?,y^?t?)=?∑yt?logy^?t??
損失函數對W求偏導:?E?W=∑t?Et?W\frac{\partial E}{\partial W}=\sum_{t} \frac{\partial E_{t}}{\partial W}?W?E?=∑t??W?Et??
假設t=3,?E3?W=?E3?y^3?y^3?s3?s3?W\frac{\partial E_{3}}{\partial W}=\frac{\partial E_{3}}{\partial \hat{y}_{3}} \frac{\partial \hat{y}_{3}}{\partial s_{3}} \frac{\partial s_{3}}{\partial W}?W?E3??=?y^?3??E3???s3??y^?3???W?s3??
E3E_3E3?和y3y_3y3?有關系,y3y_3y3?和s3s_3s3?有關系(參考2.1中的公式)。
而s3=tanh(Ux3+Ws2)s_3=tanh(Ux_3+Ws_2)s3?=tanh(Ux3?+Ws2?),s3s_3s3?和s2s_2s2?有關系,我們對s3s_3s3?對W求偏導不能直接等于s2s_2s2?,因為s2s_2s2?也和W有關系。
s2=tanh(Ux2+Ws1)s_2=tanh(Ux_2+Ws_1)s2?=tanh(Ux2?+Ws1?)
s2s_2s2?和s1s_1s1?有關系…一直到0時刻。所以我們會把每個時刻的相關梯度值相加:?s3?W=∑k=03?s3?sk?sk?W\frac{\partial s_{3}}{\partial W}=\sum_{k=0}^{3} \frac{\partial s_{3}}{\partial s_{k}} \frac{\partial s_{k}}{\partial W}?W?s3??=k=0∑3??sk??s3???W?sk??
至于這里為什么要把每個時刻的梯度相加可以參考文檔,這里直接就是說相加。還有一些解釋是:因為分子是向量,分母是矩陣,需要拆開來求導。或者根本上來講是因為求導公式,我暫時沒弄明白這一步。
其中我們在計算?s3?s2\dfrac{\partial s_3}{\partial s_2}?s2??s3??的時候需要使用鏈式法則計算:?s3?s1=?s3?s2?s2?s1?s1?s0\dfrac{\partial s_3}{\partial s_1}=\dfrac{\partial s_3}{\partial s_2}\dfrac{\partial s_2}{\partial s_1}\dfrac{\partial s_1}{\partial s_0}?s1??s3??=?s2??s3???s1??s2???s0??s1??
所以最終得到:?E3?W=∑k=03?E3?y^3?y^3?s3?s3?sk?sk?W=∑k=03?E3?y^3?y^3?s3(∏j=k+13?sj?sj?1)?sk?W\frac{\partial E_{3}}{\partial W}=\sum_{k=0}^{3} \frac{\partial E_{3}}{\partial \hat{y}_{3}} \frac{\partial \hat{y}_{3}}{\partial s_{3}} \frac{\partial s_{3}}{\partial s_{k}} \frac{\partial s_{k}}{\partial W} =\sum_{k=0}^{3} \frac{\partial E_{3}}{\partial \hat{y}_{3}} \frac{\partial \hat{y}_{3}}{\partial s_{3}}\left(\prod_{j=k+1}^{3} \frac{\partial s_{j}}{\partial s_{j-1}}\right) \frac{\partial s_{k}}{\partial W}?W?E3??=k=0∑3??y^?3??E3???s3??y^?3???sk??s3???W?sk??=k=0∑3??y^?3??E3???s3??y^?3?????j=k+1∏3??sj?1??sj??????W?sk??
看公式中有連乘的部分。當使用tanh作為激活函數的時候,由于導數值分別在0到1之間,隨著時間的累計,小于1的數不斷相城,很容易趨近于0。(另外一種解釋:如果權重矩陣 W的范數也不很大,那么經過 𝑡?𝑘 次傳播后,?s3?sk\dfrac{\partial s_3}{\partial s_k}?sk??s3??的范數會趨近于0,這也就導致了梯度消失。)
梯度消失帶來的一個問題就是記憶力有限,離得越遠的東西記住得越少。
4 LSTM
LSTM就是為了解決普通RNN中的梯度消失問題提出的。
LSTM提出了記憶細胞C,以及各種門。下圖中的h與上面的S是相同含義,表示記憶。每個時刻的輸出,在這里是沒有畫出來的。
假設現在有一個任務是根據已經讀到的詞,預測下一個詞。例如輸入法,生成詩詞。
第1步:忘記門:從記憶細胞中丟棄一些信息
使用sigmoid函數,經過sigmoid之后得到一個概率值,描述每個部分有多少量可以通過。
ft=σ(Wf?[ht?1,xt]+bf)f_{t}=\sigma\left(W_{f} \cdot\left[h_{t-1}, x_{t}\right]+b_{f}\right)ft?=σ(Wf??[ht?1?,xt?]+bf?)
如果C中包含當前對象的性別屬性,現在已經正確的預測了當前的名詞。當我們看到另外一個新的對象的時候,我們希望忘記舊對象的性別屬性。
第2步:更新什么新信息到記憶中
sigmoid決定什么值需要更新: it=σ(Wi?[ht?1,xt]+bi)i_{t}=\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right)it?=σ(Wi??[ht?1?,xt?]+bi?)
tanh層創建一個新的候選值向量(高三這一年學到的所有知識): C~t=tanh?(WC?[ht?1,xt]+bC)\tilde{C}_{t}=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)C~t?=tanh(WC??[ht?1?,xt?]+bC?)
第3步:更新記憶細胞
把舊狀態與ftf_tft?相乘,丟棄掉我們確定需要丟棄的信息;
加上iti_tit?*C~t\tilde{C}_{t}C~t?,就是新的候選值,更新狀態。
Ct=ft?Ct?1+it?C~tC_{t}=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t}Ct?=ft??Ct?1?+it??C~t?
Ct?1C_{t-1}Ct?1?是到高二以及之前的所有記憶,C~t\tilde{C}_{t}C~t?高三這一年學到的所有知識。帶著兩部分應該留下的內容去高考。
在任務中就是希望把新看到對象的性別屬性添加到C,而把舊對象的性別屬性刪除。
第4步,基于細胞狀態得到輸出
首先一個sigmoid層確定細胞狀態的哪個部分的值將輸出:ot=σ(Wo[ht?1,xt]+bo)o_{t}=\sigma\left(W_{o}\left[h_{t-1}, x_{t}\right]+b_{o}\right)ot?=σ(Wo?[ht?1?,xt?]+bo?)
接著用tanh處理細胞狀態,輸出我們確定輸出的那部分,這部分是記憶用于下一時刻幫助做出決策的:ht=ot?tanh?(Ct)h_{t}=o_{t} * \tanh \left(C_{t}\right)ht?=ot??tanh(Ct?)
在語言模型中,既然我當前看到了一個對象,這里可能輸出一個動詞信息,以備下一步需要用到。例如這里可能輸出當前對象是單數還是復數,這樣就知道下一個動詞應該填寫什么形式。
總結:
1:決定老細胞只留下哪部分ft=σ(Wf?[ht?1,xt]+bf)f_{t}=\sigma\left(W_{f} \cdot\left[h_{t-1}, x_{t}\right]+b_{f}\right)ft?=σ(Wf??[ht?1?,xt?]+bf?)
2: 決定新知識應該記住哪部分:it=σ(Wi?[ht?1,xt]+bi)i_{t}=\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right)it?=σ(Wi??[ht?1?,xt?]+bi?)
新學習到的知識:C~t=tanh?(WC?[ht?1,xt]+bC)\tilde{C}_{t}=\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)C~t?=tanh(WC??[ht?1?,xt?]+bC?)
3 更新細胞狀態:Ct=ft?Ct?1+it?C~tC_{t}=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t}Ct?=ft??Ct?1?+it??C~t?
4 決定要輸出哪部分:ot=σ(Wo[ht?1,xt]+bo)o_{t}=\sigma\left(W_{o}\left[h_{t-1}, x_{t}\right]+b_{o}\right)ot?=σ(Wo?[ht?1?,xt?]+bo?)
產生隱藏狀態的輸出:ht=ot?tanh?(Ct)h_{t}=o_{t} * \tanh \left(C_{t}\right)ht?=ot??tanh(Ct?)
對比普通的RNN,輸出ot=σ(VSt)o_t=\sigma\left(VS_t\right)ot?=σ(VSt?),St=tanh(Uxt+WSt?1)S_t=tanh(Ux_t+WS_{t-1})St?=tanh(Uxt?+WSt?1?),對于記憶StS_tSt?是由之前記憶和新知識共同組成。加入細胞狀態可以選擇忘記一部分老知識和選擇忘記一部分新知識。
在之前的求導過程中?s3?s1=?s3?s2?s2?s1?s1?s0\dfrac{\partial s_3}{\partial s_1}=\dfrac{\partial s_3}{\partial s_2}\dfrac{\partial s_2}{\partial s_1}\dfrac{\partial s_1}{\partial s_0}?s1??s3??=?s2??s3???s1??s2???s0??s1??,現在變為。。。。。
輸出ot=σ(Vht)o_t=\sigma\left(Vh_t\right)ot?=σ(Vht?)
ht=ot?tanh?(Ct)h_{t}=o_{t} * \tanh \left(C_{t}\right)ht?=ot??tanh(Ct?)
Ct=ft?Ct?1+it?C~t=ft?Ct?1+it?tanh?(WC?[ht?1,xt]+bC)C_{t}=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t}=f_{t} * C_{t-1}+i_t*\tanh \left(W_{C} \cdot\left[h_{t-1}, x_{t}\right]+b_{C}\right)Ct?=ft??Ct?1?+it??C~t?=ft??Ct?1?+it??tanh(WC??[ht?1?,xt?]+bC?)
損失函數不變,還是令t=3,?E3?W=?E3?y^3?y^3?h3?h3?C3?C3?Wc\frac{\partial E_{3}}{\partial W}=\frac{\partial E_{3}}{\partial \hat{y}_{3}} \frac{\partial \hat{y}_{3}}{\partial h_{3}} \frac{\partial h_{3}}{\partial C_3}\frac{\partial C_{3}}{\partial W_c}?W?E3??=?y^?3??E3???h3??y^?3???C3??h3???Wc??C3??
要求?C3?Wc\dfrac{\partial C_3}{\partial W_c}?Wc??C3??,這樣CtC_tCt?與WcW_cWc?有關系,Ct?1C_{t-1}Ct?1?與WcW_cWc?有關系,兩部分相加,對整個函數求導,就是對這兩部分分別求導,再相加。與普通RNN的相乘
?C3?C1=?C3?C2+?C2?C1=?\dfrac{\partial C_3}{\partial C_1}=\dfrac{\partial C_3}{\partial C_2}+\dfrac{\partial C_2}{\partial C_1}=??C1??C3??=?C2??C3??+?C1??C2??=?
5 GRU
GRU是LSTM的變種之一。
GRU做的改變是:
1 將忘記門和輸入門合并成一個門,稱為更新門。
2 細胞狀態和隱藏狀態,也就是上面的C和hth_tht?合并為一個hth_tht?。
這樣GRU的參數就比標準LSTM要少,在很多情況下效果基本一致。
總結
以上是生活随笔為你收集整理的深度学习04-RNN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 交换机tagged与untagged的关
- 下一篇: iOS内存管理(ARC,MRC)