关于知识蒸馏,你想知道的都在这里!
"蒸餾",一個(gè)化學(xué)用語,在不同的沸點(diǎn)下提取出不同的成分。知識蒸餾就是指一個(gè)很大很復(fù)雜的模型,有著非常好的效果和泛化能力,這是缺乏表達(dá)能力的小模型所不能擁有的。因此從大模型學(xué)到的知識用于指導(dǎo)小模型,使得小模型具有大模型的泛化能力,并且參數(shù)量顯著降低,壓縮了模型提升了性能,這就是知識蒸餾。<Distilling the Knowledge in a Neural Network>這篇論文首次提出了知識蒸餾的概念,核心思想就是訓(xùn)練一個(gè)復(fù)雜模型,把這個(gè)復(fù)雜模型的輸出和有l(wèi)abel的數(shù)據(jù)一并喂給了小網(wǎng)絡(luò),所以知識蒸餾一定會有個(gè)復(fù)雜的大模型(teacher model)和一個(gè)小模型(student model)。
為什么要蒸餾?
模型越來越深,網(wǎng)絡(luò)越來越大,參數(shù)越來越多,效果越來越好,但是計(jì)算復(fù)雜度呢?一并上升,蒸餾就是個(gè)特別好的方法,用于壓縮模型的大小。
- 提升模型準(zhǔn)確率:如果你不滿意現(xiàn)有小模型的效果,可以訓(xùn)練一個(gè)高度復(fù)雜效果極佳的大模型(teacher model),然后用它指導(dǎo)小模型達(dá)到你滿意的效果。
- 降低模型延遲,壓縮網(wǎng)絡(luò)參數(shù):網(wǎng)絡(luò)延遲大?像是bert這種大模型,是否可以用一個(gè)一層,減少head數(shù)的簡單網(wǎng)絡(luò)去學(xué)習(xí)bert呢,這樣不僅提升了簡單網(wǎng)絡(luò)的準(zhǔn)確率,也實(shí)現(xiàn)了延遲的降低。
- 遷移學(xué)習(xí):比方說一個(gè)老師知道分辨貓狗,另一個(gè)老師知道分辨香蕉蘋果,那學(xué)生從這兩個(gè)老師學(xué)習(xí)就能同時(shí)分辨貓狗和香蕉蘋果。
順便回顧下之前探討過的模型壓縮5種方法:
- Model pruning
- Quantification
- Knowledge distillation
- Parameter sharing
- Parameter matrix approximation
理想情況下,我們是希望同樣一份訓(xùn)練數(shù)據(jù),無論是大模型還是小模型,他們收斂的空間重合度很高,但實(shí)際情況由于大模型搜索空間較大,小模型較小,他們收斂的重合度就較低,知識蒸餾能提升他們之間的重合度使得小模型有更好的泛化能力。
知識蒸餾最基礎(chǔ)的框架:
使用Teacher-Student model,用一個(gè)非常大而復(fù)雜的老師模型,輔助學(xué)生模型訓(xùn)練。老師模型巨大復(fù)雜,因此不用于在線,學(xué)生模型部署在線上,靈活小巧易于部署。知識蒸餾可以簡單的分為兩個(gè)主要的方向:target-based蒸餾,feature-based蒸餾。
Target distillation-Logits method
上文提到的那篇最經(jīng)典的論文就是該方法一個(gè)很好的例子。這篇論文解決的是一個(gè)分類問題,既然是分類問題模型就會有個(gè)softmax層,該層輸出值直接就是每個(gè)類別的概率,在知識蒸餾中,因?yàn)槲覀冇袀€(gè)很好的老師模型,一個(gè)最直接的方法就是讓學(xué)生模型去擬合老師模型輸出的每個(gè)類別的概率,也就是我們常說的"Soft-target"。
Hard-target and Soft-target
模型要能訓(xùn)練,必須定義loss函數(shù),目標(biāo)就是讓預(yù)測值更接近真實(shí)值,真實(shí)值就是Hard-target,loss函數(shù)會使得偏差越來越小。在知識蒸餾中,直接學(xué)習(xí)每個(gè)類別的概率(老師模型預(yù)估的)就是soft-target。
Hard-target:類似one-hot的label,比如二分類,正例是1,負(fù)例是0。
Soft-target:老師模型softmax層輸出的概率分布,概率最大的就是正類別。
知識蒸餾使得老師模型的soft-target去指導(dǎo)用hard-target學(xué)習(xí)的學(xué)生模型,為什么是有效的呢?因?yàn)槔蠋熌P洼敵龅膕oftmax層攜帶的信息要遠(yuǎn)多于hard-target,老師模型給學(xué)生模型不僅提供了正例的信息,也提供了負(fù)例的概率,所以學(xué)生模型可以學(xué)到更多hard-target學(xué)不到的東西。
知識蒸餾具體方法:
神經(jīng)網(wǎng)絡(luò)用softmax層去計(jì)算各類的概率:
但是直接使用softmax的輸出作為soft-target會有其他問題,當(dāng)softmax輸出的概率分布的熵相對較小時(shí),負(fù)類別的label就接近0,對loss函數(shù)的共享就非常小,小到可以忽略。所以可以新增個(gè)變量"temperature",用下式去計(jì)算softmax函數(shù):
當(dāng)T是1,就是以前的softmax模型,當(dāng)T非常大,那輸出的概率會變的非常平滑,會有很大的熵,模型就會更加關(guān)注負(fù)類別。
具體蒸餾流程如下:
1.訓(xùn)練老師模型;
2.使用個(gè)較高的溫度去構(gòu)建Soft-target;
3.同時(shí)使用較高溫度的Soft-target和T=1的Soft-target去訓(xùn)練學(xué)生模型;
4.把T改為1在學(xué)生模型上做預(yù)估。
老師模型的訓(xùn)練過程非常簡單。學(xué)生模型的目標(biāo)函數(shù)可以同時(shí)使用兩個(gè)loss,一個(gè)是蒸餾loss,另一個(gè)是本身的loss,用權(quán)重控制,如下式所示:
老師和學(xué)生使用相同的溫度T,vi適合zi指softmax輸出的logits。L_hard用的就是溫度1。
L_hard的重要性不言而喻,老師也可能會教錯(cuò)!使用L_hard能避免老師的錯(cuò)誤傳遞給學(xué)生。L_soft和L_hard之前的權(quán)重也比較重要,實(shí)驗(yàn)表明L_hard權(quán)重較小往往帶來更好的效果,因?yàn)長_soft的梯度貢獻(xiàn)大約是1/T^2,所以L_soft最好乘上一個(gè)T^2去確保兩個(gè)loss的梯度貢獻(xiàn)等同。
一種特殊形式的蒸餾方式:Direct Matching Logits
直接使用softmax層產(chǎn)出的logits作為soft-target,目標(biāo)函數(shù)直接使用均方誤差,如下所示:
和傳統(tǒng)蒸餾方法相比,T趨向于無窮大時(shí),直接擬合logits和擬合概率是等同的(證明略),所以這是一種特殊形式的蒸餾方式。
關(guān)于溫度:
一個(gè)較高的溫度,往往能蒸餾出更多知識,但是怎么去調(diào)節(jié)溫度呢?
- 最原始的softmax函數(shù)就是T=1,當(dāng)T < 1,概率分布更"陡",當(dāng)T->0,輸出值就變成了Hard-target,當(dāng)T > 1,概率分布就會更平滑。
- 當(dāng)T變大,概率分布熵會變大,當(dāng)T趨于無窮,softamx結(jié)果就均勻分布了。
- 不管T是多少,Soft-target會攜帶更多具有傾向性的信息。
T的變化程度決定了學(xué)生模型有多少attention在負(fù)類別上,當(dāng)溫度很低,模型就不太關(guān)注負(fù)類別,特別是那些小于均值的負(fù)類別,當(dāng)溫度很高,模型就更多的關(guān)注負(fù)類別。事實(shí)上負(fù)類別攜帶更多信息,特別是大于均值的負(fù)類別。因此選對溫度很重要,需要更多實(shí)驗(yàn)去選擇。T的選擇和學(xué)生模型的大小關(guān)系也很大,當(dāng)學(xué)生模型相對較小,一個(gè)較小的T就足夠了,因?yàn)閷W(xué)生模型沒有能力學(xué)習(xí)老師模型全部的知識,一些負(fù)類別信息就可以忽略。
除此以外,還有很多特別的蒸餾思想,如intermediate based蒸餾,如下圖所示,蒸餾的不僅僅是softmax層,連中間層一并蒸餾。
關(guān)于"知識蒸餾",你想知道的都在這里!總結(jié)
以上是生活随笔為你收集整理的关于知识蒸馏,你想知道的都在这里!的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 时间序列里面最强特征之一
- 下一篇: 炼丹秘术:给Embedding插上翅膀