也来谈谈RNN的梯度消失/爆炸问题
?PaperWeekly 原創(chuàng) ·?作者|蘇劍林
單位|追一科技
研究方向|NLP、神經(jīng)網(wǎng)絡
盡管 Transformer 類的模型已經(jīng)攻占了 NLP 的多數(shù)領域,但諸如 LSTM、GRU?之類的 RNN?模型依然在某些場景下有它的獨特價值,所以 RNN 依然是值得我們好好學習的模型。而于 RNN 梯度的相關分析,則是一個從優(yōu)化角度思考分析模型的優(yōu)秀例子,值得大家仔細琢磨理解。君不見,諸如“LSTM 為什么能解決梯度消失/爆炸”等問題依然是目前流行的面試題之一。
▲經(jīng)典的LSTM
關于此類問題,已有不少網(wǎng)友做出過回答,然而筆者查找了一些文章(包括知乎上的部分回答、專欄以及經(jīng)典的英文博客),發(fā)現(xiàn)沒有找到比較好的答案:有些推導記號本身就混亂不堪,有些論述過程沒有突出重點,整體而言感覺不夠清晰自洽。為此,筆者也嘗試給出自己的理解,供大家參考。
RNN及其梯度
RNN 的統(tǒng)一定義為:
其中 是每一步的輸出,它由當前輸入 和前一時刻輸出 共同決定,而 則是可訓練參數(shù)。在做最基本的分析時,我們可以假設 都是一維的,這可以讓我們獲得最直觀的理解,并且其結果對高維情形仍有參考價值。之所以要考慮梯度,是因為我們目前主流的優(yōu)化器還是梯度下降及其變種,因此要求我們定義的模型有一個比較合理的梯度。我們可以求得:
可以看到,其實 RNN 的梯度也是一個 RNN,當前時刻梯度 是前一時刻梯度 與當前運算梯度 的函數(shù)。同時,從上式我們就可以看出,其實梯度消失或者梯度爆炸現(xiàn)象幾乎是必然存在的:
當 時,意味著歷史的梯度信息是衰減的,因此步數(shù)多了梯度必然消失(好比 );當 ,因為這歷史的梯度信息逐步增強,因此步數(shù)多了梯度必然爆炸(好比 )。總不可能一直 吧?當然,也有可能有些時刻大于 1,有些時刻小于 1,最終穩(wěn)定在 1 附近,但這樣概率很小,需要很精巧地設計模型才行。
所以步數(shù)多了,梯度消失或爆炸幾乎都是不可避免的,我們只能對于有限的步數(shù)去緩解這個問題。
消失還是爆炸?
說到這里,我們還沒說清楚一個問題:什么是 RNN 的梯度消失/爆炸?梯度爆炸好理解,就是梯度數(shù)值發(fā)散,甚至慢慢就 NaN 了;那梯度消失就是梯度變成零嗎?并不是,我們剛剛說梯度消失是 一直小于 1,歷史梯度不斷衰減,但不意味著總的梯度就為 0 了,具體來說,一直迭代下去,我們有:
顯然,其實只要 不為 0,那么總梯度為 0 的概率其實是很小的;但是一直迭代下去的話,那么 這一項前面的稀疏就是 t-1 項的連乘 ,如果它們的絕對值都小于 1,那么結果就會趨于 0,這樣一來, 幾乎就沒有包含最初的梯度 的信息了。
這才是 RNN 中梯度消失的含義:距離當前時間步越長,那么其反饋的梯度信號越不顯著,最后可能完全沒有起作用,這就意味著 RNN 對長距離語義的捕捉能力失效了。
說白了,你優(yōu)化過程都跟長距離的反饋沒關系,怎么能保證學習出來的模型能有效捕捉長距離呢?
幾個數(shù)學公式
上面的文字都是一般性的分析,接下來我們具體 RNN 具體分析。不過在此之前,我們需要回顧幾條數(shù)學公式,后面的推導中我們將多次運用到這幾條公式:
其中 是 sigmoid 函數(shù)。這幾條公式其實就是說了這么一件事: 和 基本上是等價的,它們的導數(shù)均可以用它們自身來表示。
簡單RNN分析
首先登場的是比較原始的簡單 RNN(有時候我們確實直接稱它為 SimpleRNN),它的公式為:
其中 W,U,b 是待優(yōu)化參數(shù)。看到這里很自然就能提出第一個疑問:為什么激活函數(shù)用 而不是更流行的 ?這是個好問題,我們很快就會回答它。
從上面的討論中我們已經(jīng)知道,梯度消失還是爆炸主要取決于 ,所以我們計算:
由于我們無法確定 U 的范圍,因此 可能小于 1 也可能大于 1,梯度消失/爆炸的風險是存在的。但有意思的是,如果 |U| 很大,那么相應地 就會很接近 1 或 -1,這樣 反而會小,事實上可以嚴格證明:如果固定 ,那么 作為 U 的函數(shù)是有界的,也就是說不管 U 等于什么,它都不超過一個固定的常數(shù)。
這樣一來,我們就能回答為什么激活函數(shù)要用 了,因為激活函數(shù)用 后,對應的梯度 是有界的,雖然這個界未必是 1,但一個有界的量不超過 1 的概率總高于無界的量,因此梯度爆炸的風險更低。相比之下,如果用 激活的話,它在正半軸的導數(shù)恒為 1,此時 是無界的,梯度爆炸風險更高。
所以,RNN 用 而不是 的主要目的就是緩解梯度爆炸風險。當然,這個緩解是相對的,用了 依然有爆炸的可能性。事實上,處理梯度爆炸的最根本方法是參數(shù)裁剪或梯度裁剪,說白了,就是我人為地把 U 給裁剪到 [-1,1] 內,那不就可以保證梯度不爆了嗎?
當然,又有讀者會問,既然裁剪可以解決問題,那么是不是可以用 了?確實是這樣子,配合良好的初始化方法和參數(shù)/梯度裁剪方案, 版的 RNN 也可以訓練好,但是我們還是愿意用 ,這還是因為它對應的 有界,要裁剪也不用裁剪得太厲害,模型的擬合能力可能會更好。
LSTM的結果
當然,裁剪的方式雖然也能 work,但終究是無奈之舉,況且裁剪也只能解決梯度爆炸問題,解決不了梯度消失,如果能從模型設計上解決這個問題,那么自然是最好的。傳說中的 LSTM 就是這樣的一種設計,真相是否如此?我們馬上來分析一下。
LSTM 的更新公式比較復雜,它是:
我們可以像上面一樣計算 ,但從 可以看出分析 就等價于分析 ,而計算 顯得更加簡單一些,因此我們往這個方向走。
同樣地,我們先只關心 1 維的情形,這時候根據(jù)求導公式,我們有:
右端第一項 ,也就是我們所說的“遺忘門”,從下面的論述我們可以知道一般情況下其余三項都是次要項,因此 是“主項”,由于 在 0~1 之間,因此就意味著梯度爆炸的風險將會很小,至于會不會梯度消失,取決于 是否接近于 1。
但非常碰巧的是,這里有個相當自洽的結論:如果我們的任務比較依賴于歷史信息,那么 就會接近于 1,這時候歷史的梯度信息也正好不容易消失;如果 很接近于 0,那么就說明我們的任務不依賴于歷史信息,這時候就算梯度消失也無妨了。
所以,現(xiàn)在的關鍵就是看“其余三項都是次要項”這個結論能否成立。后面的三項都是“一項乘以另一項的偏導”的形式,而且求偏導的項都是 或 激活,前面在回顧數(shù)學公式的時候說了 和 基本上是等價的,因此后面三項是類似的,分析了其中一項就相當于分析了其余兩項。以第二項為例,代入 ,可以算得:
注意到 ,都是在 0~1 之間,也可以證明 ,因此它也在 - 1~1 之間。所以說白了 就相當于 1 個 乘上 4 個門,結果會變得更加小,所以只要初始化不是很糟糕,那么它都會被壓縮得相當小,因此占不到主導作用。
跟簡單 RNN 的梯度(6)相比,它也多出了 3 個門,說直觀一點那就是:1 個門我壓不垮你,多來幾個門還不行么?
剩下兩項的結論也是類似的:
所以,后面三項的梯度帶有更多的“門”,一般而言乘起來后會被壓縮的更厲害,因此占主導的項還是 , 在 0~1 之間這個特性決定了它梯度爆炸的風險很小,同時 表明了模型對歷史信息的依賴性,也正好是歷史梯度的保留程度,兩者相互自洽,所以 LSTM 也能較好地緩解梯度消失問題。
因此,LSTM 同時較好地緩解了梯度消失/爆炸問題,現(xiàn)在我們訓練 LSTM 時,多數(shù)情況下只需要直接調用 Adam 等自適應學習率優(yōu)化器,不需要人為對梯度做什么調整了。
當然,這些結果都是“概論”,你非要構造一個會梯度消失/爆炸的 LSTM 來,那也是能構造出來的。此外,就算 LSTM 能緩解這兩個問題,也是在一定步數(shù)內,如果你的序列很長,比如幾千上萬步,那么該消失的還會消失。畢竟單靠一個向量,也緩存不了那么多信息啊~
順便看看GRU
在文章結束之前,我們順便對 LSTM 的強力競爭對手 GRU 也做一個分析。GRU 的運算過程為:
還有個更極端的版本是將 合成一個:
不管是哪一個,我們發(fā)現(xiàn)它在算 的時候, 都是先乘個 變成 ,不知道讀者是否困惑過這一點?直接用 不是更簡潔更符合直覺嗎?
首先我們觀察到,而 一般全零初始化, 則因為 激活,因此結果必然在 -1~1 之間,所以作為 與 的加權平均的 也一直保持在 -1~1 之間,因此 本身就有類似門的作用。這跟LSTM的 不一樣,理論上 是有可能發(fā)散的。了解到這一點后,我們再去求導:
其實結果跟 LSTM 的類似,主導項應該是 ,但剩下的項比 LSTM 對應的項少了 1 個門,因此它們的量級可能更大,相對于 LSTM 的梯度其實更不穩(wěn)定,特別是 這步操作,雖然給最后一項引入了多一個門 ,但也同時引入了多一項 ,是好是歹很難說。總體相對而言,感覺 GRU 應該會更不穩(wěn)定,比 LSTM 更依賴于好的初始化方式。
針對上述分析結果,個人認為如果沿用 GRU 的思想,又需要簡化 LSTM 并且保持 LSTM 對梯度的友好性,更好的做法是把 放到最后:
當然,這樣需要多緩存一個變量,帶來額外的顯存消耗了。
文章總結概述
本文討論了 RNN 的梯度消失/爆炸問題,主要是從梯度函數(shù)的有界性、門控數(shù)目的多少來較為明確地討論 RNN、LSTM、GRU 等模型的梯度流情況,以確定其中梯度消失/爆炸風險的大小。本文屬于閉門造車之作,如有錯漏,請讀者海涵并斧正。
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質內容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創(chuàng)作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發(fā),請在投稿時提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認每篇文章都是首發(fā),均會添加“原創(chuàng)”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發(fā)送?
? 請留下即時聯(lián)系方式(微信或手機),以便我們在編輯發(fā)布時和作者溝通
????
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的也来谈谈RNN的梯度消失/爆炸问题的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: NeurIPS 2020 | 利用像素级
- 下一篇: 2018年大额存款利息多少