模型优化漫谈:BERT的初始标准差为什么是0.02?
?PaperWeekly 原創(chuàng) ·?作者 | 蘇劍林
單位 |?追一科技
研究方向 | NLP、神經(jīng)網(wǎng)絡(luò)
前幾天在群里大家討論到了“Transformer 如何解決梯度消失”這個(gè)問題,答案有提到殘差的,也有提到 LN(Layer Norm)的。這些是否都是正確答案呢?事實(shí)上這是一個(gè)非常有趣而綜合的問題,它其實(shí)關(guān)聯(lián)到挺多模型細(xì)節(jié),比如“BERT 為什么要 warmup?”、“BERT 的初始化標(biāo)準(zhǔn)差為什么是 0.02?”、“BERT 做 MLM預(yù)測之前為什么還要多加一層 Dense?”,等等。本文就來集中討論一下這些問題。
梯度消失說的是什么意思?
在文章《也來談?wù)?RNN 的梯度消失/爆炸問題》中,我們?cè)懻撨^ RNN 的梯度消失問題。事實(shí)上,一般模型的梯度消失現(xiàn)象也是類似,它指的是(主要是在模型的初始階段)越靠近輸入的層梯度越小,趨于零甚至等于零,而我們主要用的是基于梯度的優(yōu)化器,所以梯度消失意味著我們沒有很好的信號(hào)去調(diào)整優(yōu)化前面的層。
換句話說,前面的層也許幾乎沒有得到更新,一直保持隨機(jī)初始化的狀態(tài);只有比較靠近輸出的層才更新得比較好,但這些層的輸入是前面沒有更新好的層的輸出,所以輸入質(zhì)量可能會(huì)很糟糕(因?yàn)榻?jīng)過了一個(gè)近乎隨機(jī)的變換),因此哪怕后面的層更新好了,總體效果也不好。最終,我們會(huì)觀察到很反直覺的現(xiàn)象:模型越深,效果越差,哪怕訓(xùn)練集都如此。
解決梯度消失的一個(gè)標(biāo)準(zhǔn)方法就是殘差鏈接,正式提出于 ResNet [1] 中。殘差的思想非常簡單直接:你不是擔(dān)心輸入的梯度會(huì)消失嗎?那我直接給它補(bǔ)上一個(gè)梯度為常數(shù)的項(xiàng)不就行了?最簡單地,將模型變成
這樣一來,由于多了一條“直通”路 ,就算 中的 梯度消失了, 的梯度基本上也能得以保留,從而使得深層模型得到有效的訓(xùn)練。
LN真的能緩解梯度消失?
然而,在 BERT 和最初的 Transformer 里邊,使用的是 Post Norm 設(shè)計(jì),它把 Norm 操作加在了殘差之后:
其實(shí)具體的 Norm 方法不大重要,不管是 Batch Norm 還是 Layer Norm,結(jié)論都類似。在文章《淺談 Transformer 的初始化、參數(shù)化與標(biāo)準(zhǔn)化》[2] 中,我們已經(jīng)分析過這種 Norm 結(jié)構(gòu),這里再來重復(fù)一下。
在初始化階段,由于所有參數(shù)都是隨機(jī)初始化的,所以我們可以認(rèn)為 與 是兩個(gè)相互獨(dú)立的隨機(jī)向量,如果假設(shè)它們各自的方差是 1,那么 的方差就是 2,而 操作負(fù)責(zé)將方差重新變?yōu)?1,那么在初始化階段, 操作就相當(dāng)于“除以 ”:
遞歸下去就是:
我們知道,殘差有利于解決梯度消失,但是在 Post Norm 中,殘差這條通道被嚴(yán)重削弱了,越靠近輸入,削弱得越嚴(yán)重,殘差“名存實(shí)亡”。所以說,在 Post Norm 的 BERT 模型中,LN 不僅不能緩解梯度消失,它還是梯度消失的“元兇”之一。
那我們?yōu)槭裁催€要加LN
那么,問題自然就來了:既然 LN 還加劇了梯度消失,那直接去掉它不好嗎?
是可以去掉,但是前面說了, 的方差就是 2 了,殘差越多方差就越大了,所以還是要加一個(gè) Norm 操作,我們可以把它加到每個(gè)模塊的輸入,即變?yōu)?,最后的總輸出再加個(gè) 就行,這就是 Pre Norm 結(jié)構(gòu),這時(shí)候每個(gè)殘差分支是平權(quán)的,而不是像 Post Norm 那樣有指數(shù)衰減趨勢(shì)。
當(dāng)然,也有完全不加 Norm 的,但需要對(duì) 進(jìn)行特殊的初始化,讓它初始輸出更接近于 0,比如 ReZero、Skip Init、Fixup 等,這些在《淺談 Transformer 的初始化、參數(shù)化與標(biāo)準(zhǔn)化》[2] 也都已經(jīng)介紹過了。
但是,拋開這些改進(jìn)不說,Post Norm 就沒有可取之處嗎?難道 Transformer 和 BERT 開始就帶了一個(gè)完全失敗的設(shè)計(jì)?
顯然不大可能。雖然 Post Norm 會(huì)帶來一定的梯度消失問題,但其實(shí)它也有其他方面的好處。最明顯的是,它穩(wěn)定了前向傳播的數(shù)值,并且保持了每個(gè)模塊的一致性。比如 BERT base,我們可以在最后一層接一個(gè) Dense 來分類,也可以取第 6 層接一個(gè) Dense 來分類;但如果你是 Pre Norm 的話,取出中間層之后,你需要自己接一個(gè) LN 然后再接 Dense,否則越靠后的層方差越大,不利于優(yōu)化。
其次,梯度消失也不全是“壞處”,其實(shí)對(duì)于 Finetune 階段來說,它反而是好處。在 Finetune 的時(shí)候,我們通常希望優(yōu)先調(diào)整靠近輸出層的參數(shù),不要過度調(diào)整靠近輸入層的參數(shù),以免嚴(yán)重破壞預(yù)訓(xùn)練效果。而梯度消失意味著越靠近輸入層,其結(jié)果對(duì)最終輸出的影響越弱,這正好是 Finetune 時(shí)所希望的。所以,預(yù)訓(xùn)練好的 Post Norm 模型,往往比 Pre Norm 模型有更好的 Finetune 效果,這我們?cè)凇禦ealFormer:把殘差轉(zhuǎn)移到 Attention 矩陣上面去》也提到過。
我們真的擔(dān)心梯度消失嗎?
其實(shí),最關(guān)鍵的原因是,在當(dāng)前的各種自適應(yīng)優(yōu)化技術(shù)下,我們已經(jīng)不大擔(dān)心梯度消失問題了。
這是因?yàn)?#xff0c;當(dāng)前 NLP 中主流的優(yōu)化器是 Adam 及其變種。對(duì)于 Adam 來說,由于包含了動(dòng)量和二階矩校正,所以近似來看,它的更新量大致上為
可以看到,分子分母是都是同量綱的,因此分?jǐn)?shù)結(jié)果其實(shí)就是 的量級(jí),而更新量就是 量級(jí)。也就是說,理論上只要梯度的絕對(duì)值大于隨機(jī)誤差,那么對(duì)應(yīng)的參數(shù)都會(huì)有常數(shù)量級(jí)的更新量;這跟 SGD 不一樣,SGD 的更新量是正比于梯度的,只要梯度小,更新量也會(huì)很小,如果梯度過小,那么參數(shù)幾乎會(huì)沒被更新。
所以,Post Norm 的殘差雖然被嚴(yán)重削弱,但是在 base、large 級(jí)別的模型中,它還不至于削弱到小于隨機(jī)誤差的地步,因此配合 Adam 等優(yōu)化器,它還是可以得到有效更新的,也就有可能成功訓(xùn)練了。當(dāng)然,只是有可能,事實(shí)上越深的 Post Norm 模型確實(shí)越難訓(xùn)練,比如要仔細(xì)調(diào)節(jié)學(xué)習(xí)率和 Warmup 等。
Warmup是怎樣起作用的?
大家可能已經(jīng)聽說過,Warmup 是Transformer訓(xùn)練的關(guān)鍵步驟,沒有它可能不收斂,或者收斂到比較糟糕的位置。為什么會(huì)這樣呢?不是說有了Adam就不怕梯度消失了嗎?
要注意的是,Adam 解決的是梯度消失帶來的參數(shù)更新量過小問題,也就是說,不管梯度消失與否,更新量都不會(huì)過小。但對(duì)于 Post Norm 結(jié)構(gòu)的模型來說,梯度消失依然存在,只不過它的意義變了。根據(jù)泰勒展開式:
也就是說增量 是正比于梯度的,換句話說,梯度衡量了輸出對(duì)輸入的依賴程度。如果梯度消失,那么意味著模型的輸出對(duì)輸入的依賴變?nèi)趿恕?/p>
Warmup 是在訓(xùn)練開始階段,將學(xué)習(xí)率從 0 緩增到指定大小,而不是一開始從指定大小訓(xùn)練。如果不進(jìn)行 Wamrup,那么模型一開始就快速地學(xué)習(xí),由于梯度消失,模型對(duì)越靠后的層越敏感,也就是越靠后的層學(xué)習(xí)得越快,然后后面的層是以前面的層的輸出為輸入的,前面的層根本就沒學(xué)好,所以后面的層雖然學(xué)得快,但卻是建立在糟糕的輸入基礎(chǔ)上的。
很快地,后面的層以糟糕的輸入為基礎(chǔ)到達(dá)了一個(gè)糟糕的局部最優(yōu)點(diǎn),此時(shí)它的學(xué)習(xí)開始放緩(因?yàn)橐呀?jīng)到達(dá)了它認(rèn)為的最優(yōu)點(diǎn)附近),同時(shí)反向傳播給前面層的梯度信號(hào)進(jìn)一步變?nèi)?#xff0c;這就導(dǎo)致了前面的層的梯度變得不準(zhǔn)。但我們說過,Adam 的更新量是常數(shù)量級(jí)的,梯度不準(zhǔn),但更新量依然是數(shù)量級(jí),意味著可能就是一個(gè)常數(shù)量級(jí)的隨機(jī)噪聲了,于是學(xué)習(xí)方向開始不合理,前面的輸出開始崩盤,導(dǎo)致后面的層也一并崩盤。
所以,如果 Post Norm 結(jié)構(gòu)的模型不進(jìn)行 Wamrup,我們能觀察到的現(xiàn)象往往是:loss 快速收斂到一個(gè)常數(shù)附近,然后再訓(xùn)練一段時(shí)間,loss 開始發(fā)散,直至 NAN。如果進(jìn)行 Wamrup,那么留給模型足夠多的時(shí)間進(jìn)行“預(yù)熱”,在這個(gè)過程中,主要是抑制了后面的層的學(xué)習(xí)速度,并且給了前面的層更多的優(yōu)化時(shí)間,以促進(jìn)每個(gè)層的同步優(yōu)化。
這里的討論前提是梯度消失,如果是 Pre Norm 之類的結(jié)果,沒有明顯的梯度消失現(xiàn)象,那么不加 Warmup 往往也可以成功訓(xùn)練。
初始標(biāo)準(zhǔn)差為什么是0.02?
喜歡扣細(xì)節(jié)的同學(xué)會(huì)留意到,BERT 默認(rèn)的初始化方法是標(biāo)準(zhǔn)差為 0.02 的截?cái)嗾龖B(tài)分布,在《淺談 Transformer 的初始化、參數(shù)化與標(biāo)準(zhǔn)化》[2] 我們也提過,由于是截?cái)嗾龖B(tài)分布,所以實(shí)際標(biāo)準(zhǔn)差會(huì)更小,大約是 。這個(gè)標(biāo)準(zhǔn)差是大還是小呢?對(duì)于 Xavier 初始化來說,一個(gè) 的矩陣應(yīng)該用 的方差初始化,而 BERT base 的 為 768,算出來的標(biāo)準(zhǔn)差是 。這就意味著,這個(gè)初始化標(biāo)準(zhǔn)差是明顯偏小的,大約只有常見初始化標(biāo)準(zhǔn)差的一半。
為什么 BERT 要用偏小的標(biāo)準(zhǔn)差初始化呢?事實(shí)上,這還是跟 Post Norm 設(shè)計(jì)有關(guān),偏小的標(biāo)準(zhǔn)差會(huì)導(dǎo)致函數(shù)的輸出整體偏小,從而使得 Post Norm 設(shè)計(jì)在初始化階段更接近于恒等函數(shù),從而更利于優(yōu)化。具體來說,按照前面的假設(shè),如果 的方差是 , 的方差是 ,那么初始化階段, 操作就相當(dāng)于除以 。如果 比較小,那么殘差中的“直路”權(quán)重就越接近于 1,那么模型初始階段就越接近一個(gè)恒等函數(shù),就越不容易梯度消失。
正所謂“我們不怕梯度消失,但我們也不希望梯度消失”,簡單地將初始化標(biāo)注差設(shè)小一點(diǎn),就可以使得 變小一點(diǎn),從而在保持 Post Norm 的同時(shí)緩解一下梯度消失,何樂而不為?那能不能設(shè)置得更小甚至全零?一般來說初始化過小會(huì)喪失多樣性,縮小了模型的試錯(cuò)空間,也會(huì)帶來負(fù)面效果。綜合來看,縮小到標(biāo)準(zhǔn)的 1/2,是一個(gè)比較靠譜的選擇了。
當(dāng)然,也確實(shí)有人喜歡挑戰(zhàn)極限的,最近筆者也看到了一篇文章,試圖讓整個(gè)模型用幾乎全零的初始化,還訓(xùn)練出了不錯(cuò)的效果,大家有興趣可以讀讀,文章為《ZerO Initialization: Initializing Residual Networks with only Zeros and Ones》[3]。
為什么MLM要多加Dense?
最后,是關(guān)于 BERT 的 MLM 模型的一個(gè)細(xì)節(jié),就是 BERT 在做 MLM 的概率預(yù)測之前,還要多接一個(gè) Dense 層和 LN 層,這是為什么呢?不接不行嗎?
之前看到過的答案大致上是覺得,越靠近輸出層的,越是依賴任務(wù)的(Task-Specified),我們多接一個(gè) Dense 層,希望這個(gè) Dense 層是 MLM-Specified 的,然后下游任務(wù)微調(diào)的時(shí)候就不是 MLM-Specified 的,所以把它去掉。這個(gè)解釋看上去有點(diǎn)合理,但總感覺有點(diǎn)玄學(xué),畢竟 Task-Specified 這種東西不大好定量分析。
這里筆者給出另外一個(gè)更具體的解釋,事實(shí)上它還是跟 BERT 用了 0.02 的標(biāo)準(zhǔn)差初始化直接相關(guān)。剛才我們說了,這個(gè)初始化是偏小的,如果我們不額外加 Dense 就乘上 Embedding 預(yù)測概率分布,那么得到的分布就過于均勻了(Softmax 之前,每個(gè) logit 都接近于 0),于是模型就想著要把數(shù)值放大。
現(xiàn)在模型有兩個(gè)選擇:第一,放大 Embedding 層的數(shù)值,但是 Embedding 層的更新是稀疏的,一個(gè)個(gè)放大太麻煩;第二,就是放大輸入,我們知道 BERT 編碼器最后一層是 LN,LN 最后有個(gè)初始化為 1 的 gamma 參數(shù),直接將那個(gè)參數(shù)放大就好。
模型優(yōu)化使用的是梯度下降,我們知道它會(huì)選擇最快的路徑,顯然是第二個(gè)選擇更快,所以模型會(huì)優(yōu)先走第二條路。這就導(dǎo)致了一個(gè)現(xiàn)象:最后一個(gè) LN 層的 gamma 值會(huì)偏大。如果預(yù)測 MLM 概率分布之前不加一個(gè) Dense+LN,那么? BERT 編碼器的最后一層的 LN 的 gamma 值會(huì)偏大,導(dǎo)致最后一層的方差會(huì)比其他層的明顯大,顯然不夠優(yōu)雅;而多加了一個(gè) Dense+LN 后,偏大的 gamma 就轉(zhuǎn)移到了新增的 LN 上去了,而編碼器的每一層則保持了一致性。
事實(shí)上,讀者可以自己去觀察一下 BERT 每個(gè) LN 層的 gamma 值,就會(huì)發(fā)現(xiàn)確實(shí)是最后一個(gè) LN 層的 gamma 值是會(huì)明顯偏大的,這就驗(yàn)證了我們的猜測~
希望大家多多海涵批評(píng)斧正
本文試圖回答了 Transformer、BERT 的模型優(yōu)化相關(guān)的幾個(gè)問題,有一些是筆者在自己的預(yù)訓(xùn)練工作中發(fā)現(xiàn)的結(jié)果,有一些則是結(jié)合自己的經(jīng)驗(yàn)所做的直觀想象。不管怎樣,算是分享一個(gè)參考答案吧,如果有不當(dāng)?shù)牡胤?#xff0c;請(qǐng)大家海涵,也請(qǐng)各位批評(píng)斧正。
參考文獻(xiàn)
[1] https://arxiv.org/abs/1512.03385
[2] https://kexue.fm/archives/8620
[3] https://arxiv.org/abs/2110.12661
特別鳴謝
感謝 TCCI 天橋腦科學(xué)研究院對(duì)于 PaperWeekly 的支持。TCCI 關(guān)注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點(diǎn)剖析、科研心得或競賽經(jīng)驗(yàn)講解等。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來。
📝?稿件基本要求:
? 文章確系個(gè)人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺(tái)已發(fā)表或待發(fā)表的文章,請(qǐng)明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競爭力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請(qǐng)備注即時(shí)聯(lián)系方式(微信),以便我們?cè)诟寮x用的第一時(shí)間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的模型优化漫谈:BERT的初始标准差为什么是0.02?的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: NeurlPS 2021论文预讲会议题全
- 下一篇: 虎式h1h2分别装备哪两种88毫米火炮。