bert剪枝系列——Are Sixteen Heads Really Better than One?
1,概述
剪枝可以分為兩種:一種是無序的剪枝,比如將權(quán)重中一些值置為0,這種也稱為稀疏化,在實際的應(yīng)用上這種剪枝基本沒有意義,因為它只能壓縮模型的大小,但很多時候做不到模型推斷加速,而在當(dāng)今的移動設(shè)備上更多的關(guān)注的是系統(tǒng)的實時相應(yīng),也就是模型的推斷速度。另一種是結(jié)構(gòu)化的剪枝,比如卷積中對channel的剪枝,這種不僅可以降低模型的大小,還可以提升模型的推斷速度。剪枝之前在卷積上應(yīng)用較多,而隨著bert之類的預(yù)訓(xùn)練模型的出現(xiàn),這一類模型通常比較大,且推斷速度較慢。例如bert在文本分類的任務(wù)上,128的序列長度,其推斷速度都只有80ms左右,這還只是單個模型,而一個大的系統(tǒng),往往是有多個模型組成的。因此bert要想在工業(yè)界,尤其是移動端落地,是極度需要模型壓縮的。
2,具體方法
看完這篇論文之后,更多的感覺是這篇論文并沒有在剪枝上有太多的貢獻(xiàn),更像是對multi head中head的數(shù)量做了一個實驗性的工作,探索了在multi head中并不是所有的head都需要,有很多head提取的信息對最終的結(jié)果并沒有什么影響,是冗余存在的。
本論文在探討在test階段,去掉一部分head是否會影響模型的性能,得到的結(jié)論是大多數(shù)都不會,而且部分還會提升性能,作者給出了三種實驗方法來證明這一點:
1,每次去掉一層中一個head,測試模型的性能
2,每次去掉一層中剩余的層,只保存一個head,測試模型的性能
3,通過梯度來判斷每個head的重要性,然后去掉一部分不重要的head,測試模型的性能
為了實現(xiàn)上述的實驗,作者對multi head的計算做了一些修改,修改后的公式如下:
在這里引入了一個系數(shù)$zeta_h$,該值的取值為0或1,它的作用是用來mask不重要的head。在訓(xùn)練時保持為1,到test的時候?qū)Σ糠謍ead mask掉。
作者在基于transformer的機(jī)器翻譯模型上和基于bert的NLI任務(wù)上做了實驗,我們來看看上面三個實驗的結(jié)果
Ablating One Head
去掉一個head,作者給出了實驗結(jié)果如下:
從上面的圖中可以看到大多數(shù)head去掉之后的模型分?jǐn)?shù)還基本分布在baseline附近,從作者給的表格數(shù)據(jù)看會更加的清晰:
上面給出的是機(jī)器翻譯的表格數(shù)據(jù),藍(lán)色的值表示性能增加,紅色的值表示性能下降,大多數(shù)情況下性能是增加的,只有少部分性能會有所下降,只有極少部分性能會下降的比較多。
Ablating All Heads but One
當(dāng)去掉一層中的其余head只保留一個head時,我們來看下模型的結(jié)果,這回作者給出了一個離散圖:
同樣的,大多數(shù)情況下的性能都分布在baseline附近,同樣看看表格會更清晰:
從上面來看除了機(jī)器翻譯中的encoder-decoder之間的attention的最后一層會出現(xiàn)性能明顯的下降,其他大多數(shù)情況都還好,甚至有的情況下性能反而上升。
上面兩種實驗都有一個共同的弊端,就是每次實驗只能對一層做head的mask,但實際過程中所有層的head都有可能會被去除,且至于去除哪些還和層與層之間的依賴性有關(guān),因此第三種方法可以來改善這個問題。
Head Importance Score for Pruning
在這里作者引入了梯度來衡量head的重要性,首先給出一個公式如下:
上面公式是對mask系數(shù)的偏導(dǎo),我們知道偏導(dǎo)的值的大小可以衡量這個維度上對損失的影響大小,在這里作者對偏導(dǎo)取了個絕對值,避免在求期望的時候正負(fù)抵消,因為無論是正值還是負(fù)值,只要絕對值比較大,就可以衡量偏導(dǎo)對損失的影響是比較大的,這里的期望是對所有樣本X的,因為單個batch是存在誤差的,因此對全量樣本計算的偏導(dǎo)求均值。
對上面的公式做一個鏈?zhǔn)睫D(zhuǎn)換,可以得到:
這樣我們就可以用這個對head的期望梯度值來衡量其重要性,然后按百分比去除head,得到的結(jié)果如下:
上面圖中的實驗是通過梯度來進(jìn)行剪枝的,虛線是通過第一種方法中的分?jǐn)?shù)來衡量head的重要性進(jìn)行剪枝的,可以看到基于梯度的效果還是很明顯的,但是剪枝范圍也是有限的,超過這個范圍之后,性能會急劇下降。
作者還測了下剪枝后模型的推斷速度,個人感覺這個推斷速度的減小真的是毫無意義:
如上圖所示,只有在batch達(dá)到16的時候才有比較明顯的速度提升,但是大多數(shù)線上運行的時候都是batch為1的。不過也不能就此下定論說減少head的數(shù)量是起不到加速效果的,個人感覺作者在這里測推斷速度的時候是存在一些問題的:作者是先訓(xùn)練,后剪枝,但剪枝之后沒有再訓(xùn)練,這也就意味著這些head仍然存在,只是將不需要的head前面的mask系數(shù)置為0而已。為什么做出這樣的認(rèn)定呢?因為在實際的multi head設(shè)計中,我們是要保證每個head得到的詞向量拼接在一起等于原始的詞向量,因為后面要進(jìn)入到前向?qū)樱仨毐3志S度一致,我猜這里作者可能是將mask掉的head得到的向量置為0,這樣這些值在下一層計算self-attention就沒有意義了,至于為什么還是有加速,原因不明。以上個人猜測。
此外單純得減少head的數(shù)量好像對加速意義不大,只有配合減小embedding size才有意義,否則計算復(fù)雜度基本一致,因為我們在做multi-attention時映射到不同子空間時,實際上是一個大的矩陣映射,這個大的矩陣的維度取決于embedding size,映射完之后再分割成多個而已。從計算上來看self-attention是耗時的,因為減少embedding size,減小序列長度都可以極大的提速(減小序列長度還會影響到前向?qū)拥乃俣龋?/p>
總結(jié)
以上是生活随笔為你收集整理的bert剪枝系列——Are Sixteen Heads Really Better than One?的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 网易校园招聘历年经典面试题汇总:C++研
- 下一篇: notepad++ 文本文件内容丢失恢复