RNN训练方法介绍-BPTT
生活随笔
收集整理的這篇文章主要介紹了
RNN训练方法介绍-BPTT
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
url:http://blog.csdn.net/sysstc/article/details/75333008
Training RNN——BPTT
由于RNN和時間序列有關(guān),因此我們不能通過Backpropagation來調(diào)參,我們使用的是Backpropagation through time(BPTT)
回顧Backpropagation
?
Backpropagation through Time(BPTT)
我們可以將RNN變成如下形式:?
參照之前的BP我們發(fā)現(xiàn),反向傳播其實就是一個梯度* 矩陣 * 激活函數(shù)的微分(放大器)。由于an-1和xn同時會影響到an,而an-1又會被an-2和xn-1影響,并且依次傳遞下去,a1的值會被x1和memory cell initial影響。因此,我們可以把RNN看成一個非常深的DNN,將input看成:Init,x1,x2,…,xn,output是yn,也就是如果n=100,那么就有100個hidden layer。
- 如何計算Cn這項的gradient呢?
BPTT同樣等于Backward pass + forward pass,forward pass 可以直接當做是一個DNN來計算,而backward pass可以看成如下所示,通過一個hidden layer,就相當于乘以一個個的放大器(activation function)?
如何更新參數(shù)呢,我們看到上面黃色的箭頭都是相同的weights,而上面藍色的箭頭也都是相同的weights。因此我們根據(jù)如下方式修改。?
- 如何計算所有的gradient呢?其實也是一個forward pass和一個backward pass?
RNN中train可能遇到的問題
- DNN:gradient vanish是因為backward pass時,每經(jīng)過一層都要經(jīng)歷一個activation function的微分,那這時如果我們使用的activation function是sigmoid,由于sigmoid的微分最大值為1/4,那么error signal就會越來越小,最后可能會出現(xiàn)梯度消失的情況。我們之前提到的解決方式是將activation function由sigmoid改成relu。
- RNN:?
- 從一個小案例分析:參考下面圖,我們要計算gradient,下面是一個RNN,Input都是1,output都是1,hidden layer 里面就一個神經(jīng)元,這里我們將hidden layer設(shè)置為一個linear。那么由于是一個linear,是不是我們就不會出現(xiàn)gradient vanish的情況呢?答案是否定的:?
我們發(fā)現(xiàn)w的不同,yn變化忽大忽小。這樣就會變得非常難處理。就式子和函數(shù)圖像來分析原因:?
這種情況就造成我們在算gradient的時候,大多都是一些極端值。?
那么從這個角度來看我們發(fā)現(xiàn)如果你這時候?qū)ctivation function設(shè)置為sigmoid或tanh,這種微分值小于1的activation function,反而能在某種程度上保護rnn,而用relu則不會削減這些大爆炸的情況。所以在處理RNN的時候還是應(yīng)該采用tanh或者sigmoid - 從一般情況來分析,今天為了簡化分析,我們將activation function都認為是linear activation function,底下的error signal 沒有考慮activation function,如果error signal對應(yīng)到的Wn-1是小于1的話,那么值就會變得很小,如果error signal對應(yīng)到的Wn-1大于1的話,那么值就會變得很大。?
- 從一個小案例分析:參考下面圖,我們要計算gradient,下面是一個RNN,Input都是1,output都是1,hidden layer 里面就一個神經(jīng)元,這里我們將hidden layer設(shè)置為一個linear。那么由于是一個linear,是不是我們就不會出現(xiàn)gradient vanish的情況呢?答案是否定的:?
- Clipped Gradient:設(shè)置一個threshold,clip(x,min,max)
- NAG:Momentum進化版。?
- Momentum是一個模范物體運動的方法,update方向取決于所在位置的Gradient的反方向+上次的movement,這樣就可能會照成數(shù)值波動。
- NAG的update方向取決于從現(xiàn)在所在的位置沿著movement再走一步所在的Gradient的反方向+上一次的movement。這樣可能會避免產(chǎn)生震蕩的情形?
- RMSProp:Adagrad的進化版?
- Adagrad:除去這個參數(shù)過去所有算出來的gradient的平方和再開根號(即對二次微分的估算)。這樣就可以動態(tài)調(diào)整learning rate
- RMSProp:?
過去的gradient會乘上α(0<α<1),那么越過去的gradient我們考慮的權(quán)重就會更小,這樣二次微分對同一個參數(shù)也會產(chǎn)生變化。?
LSTM解決Gradient vanishing problem
- LSTM的forward pass過程:?
- 只需要把data帶進去就可以求每個neural的output
- LSTM的backward pass過程:?
- 遇到了一個”+”,我們可以把這個看成一個activation function。input:a/b,output:a+b,微分值為1。
- 遇到了一個”x”,我們也把這個看成一個activation function。input:a/b,output:axb,微分值為b/a,底下圖片上的點乘代表的是element-wise。?
?
如果我們今天從yt+1做bptt,假設(shè)forget gate是開啟狀態(tài),那么我們發(fā)現(xiàn),從yt+1開始的error signal在通過灰色matrix時會乘以一個W的transposition,在通過activation function時可能會乘以一個小于1的數(shù)字之外,error signal就是一路暢行無阻的,就會一路保持constant的error signal,這就是Constant Error Carrousel(CEC)。這樣error signal一路保持constant的值的好處有什么呢?如果今天你要update黃色箭頭里的element,那么因為紅色箭頭不會隨時間衰減太多,所以藍色的值也不會太小,那么拿來update黃色箭頭的error signal不會太小。?
?
因此LSTM可以解決gradient vanish的問題,但LSTM不能處理gradient explode的問題因為error signal不只走藍色的箭頭,還有可能走綠色的箭頭,綠色的箭頭再走到藍色箭頭的部分就會一直乘以一個W的transposition,如果W的transposition是小的則沒關(guān)系,因為error signal是比較大的(之前提到過,在error signal的流動過程中它減小的很少)。但如果W的transposition是大的,這時候不斷的乘以W的transposition,值就會變得越來越大,就可能會導(dǎo)致gradient explode的情況。?
- 對于RNN,每一個step我們都會把hidden Layer的output寫到memory里面去,所以memory里面的值每次都會被完全修正,過去的東西其實一點都沒有存留下來。這就會導(dǎo)致當你修改模型中某個參數(shù)的時候,可能會造成很大的變化,也有可能沒有變化。
- 對于LSTM,如果沒有forget gate,那么我們過去得memory都會保存下來,因為我們用的是”+”。如果我們的forget gate沒有被開啟,那么這個memory就會永遠的存在。這樣如果我們的memory都會一直留下來,那么在修改某個參數(shù)時造成某個改變時,這個改變就不會消失,這樣我們的gradient就不會消失,這樣就可以保證我們的gradient不會特別小。但無法保證gradient不會explode。
RNN的變形
Better Initialization
紅色的就是initialized with identity matrix+ReLU?
總結(jié):
在training RNN的時候可能會遇到Gradient vanish和Gradient explode的問題。?
這里的解決方法:
- 設(shè)置一個threshold(min,max)
- 優(yōu)化技術(shù)?
- NAG
- RMSprop
- LSTM(或者其他變形)
- 更好的初始化,hidden layer之間的weight初始化用identity matrix,activation function用ReLU.
總結(jié)
以上是生活随笔為你收集整理的RNN训练方法介绍-BPTT的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 药学【7】
- 下一篇: 激光测距误差对激光脚点定位的影响