不仅搞定“梯度消失”,还让CNN更具泛化性:港科大开源深度神经网络训练新方法
原文鏈接:不僅搞定“梯度消失”,還讓CNN更具泛化性:港科大開源深度神經網絡訓練新方法
paper: https://arxiv.org/abs/2003.10739
code: https://github.com/d-li14/DHM
該文是港科大李鐸、陳啟峰提出的一種優化模型訓練、提升模型泛化性能與模型精度的方法,相比之前Deeply-Supervised Networks方式,所提方法可以進一步提升模型的性能。值得一讀。
Abstract
時間見證了深度神經網絡的深度的迅速提升(自LeNet的5層到ResNet的上千層),但尾端監督的訓練方式仍是當前主流方法。之前有學者提出采用深度監督(Deeply-supervised,DSN)方式緩解深度網絡的訓練難度問題,但是它不可避免的會影響深度網絡的分層特征表達能力,同時會導致前后矛盾的優化目標。
作者提出一種動態分層模仿機制(Dynamic Hierarchical Mimicking,一種廣義特征學習機制)加速CNN訓練同時使其具有更強的泛化性能。所提方法部分受DSN啟發,對給定神經網絡的中間特征進行巧妙的設置邊界分支(side branches)。每個分支可以動態的出現在主分支的特定位置,它不僅可以保留骨干網絡的特征表達能力,同時還可以研其通路產生更多樣性的特征表達。與此同時,作者提出采用概率預測匹配損失進一步提升多分支的多級交互影響,它可以確保優化過程的魯棒性,同時具有更好的泛化性能。
最后作者在分類與實例識別任務上驗證了所提方法的性能,均可取得一致性的性能提升。
Method
該部分內容首先簡單介紹一下深度監督及存在的問題,最后給出所提方法。由于該部分內容公式較多,文字較多,故這里僅進行粗略的介紹,在后面對進行一些個人理解分析。
Analysis of Deep Supervision
對于深度網絡而言,其優化目標可以描述為:
argminWmLm(Wm;D)+γR(Wm)argmin_{W_m} \mathcal{L}_m(W_m; \mathcal{D}) + \gamma \mathcal{R}(W_m) argminWm??Lm?(Wm?;D)+γR(Wm?)
其中Lm(Wm;D)\mathcal{L}_m(W_m; \mathcal{D})Lm?(Wm?;D)表示待優化的整體損失函數,而R(Wm)\mathcal{R}(W_m)R(Wm?)表示針對參數添加的一些正則化處理。對于圖像分類而言,上述損失函數可以定義為:
Lm(Wm;D)=?1N∑i=1Nfm(Wm;xi)(yi)\mathcal{L}_m(W_m; \mathcal{D})=-\frac{1}{N} \sum_{i=1}^{N} f_m(W_m;x_i)^{(y_i)} Lm?(Wm?;D)=?N1?i=1∑N?fm?(Wm?;xi?)(yi?)
另,由于正則項僅與參數有關,而與網絡結構無關,故在后續介紹中對上述公式進行簡化,得到:
argminWmLm(Wm;D)argmin_{W_m} \mathcal{L}_m(W_m; \mathcal{D}) argminWm??Lm?(Wm?;D)
一般而言,在圖像分類任務中,往往僅在網絡的head后進行損失計算。這種處理方式對于比較淺的網絡而言并沒有什么問題,但是對于極深網絡而言則會由于梯度反向傳播過程中的“梯度消失”問題導致網絡收斂緩慢或者不收斂或收斂到局部最優。
針對上述現象,Deeply-Supervised Nets提出了多級監督方式進行訓練。該訓練方式的優化目標函數可以描述為:
argminWm,WsL(Wm;D)+Ls(Wm,Ws;D)argmin_{W_m,\mathcal{W}_s} \mathcal{L}(W_m; \mathcal{D}) + \mathcal{L}_s(W_m, \mathcal{W}_s; \mathcal{D}) argminWm?,Ws??L(Wm?;D)+Ls?(Wm?,Ws?;D)
其中Ls\mathcal{L}_sLs?表示額外監督信息的損失。注:GoogLeNet一文采用的訓練方式就是它的一種特例。
通過上述上述訓練方式,中間層不僅可以從頂層損失獲取梯度信息,還可以從分支損失獲取提取信息,這使得其具有緩解“梯度消失”,加速網絡收斂的功能。
然而,直接在中間層添加額外的監督信息的方式在訓練極深網絡時可能會導致模型性能下降。眾所周知,深度網絡具有極強的分層特征表達能力,其特征會隨網絡深度而變化(底層特征聚焦邊緣特征而缺乏語義信息,而高層特征則聚焦于語義信息)。在底層添加強監督信息會導致深度網絡的上述特征表達方式被破壞,進而導致模型的性能下降。這從某種程度上解釋了為何上述監督方式對模型的性能提升比較小(大概在0.5%左右,甚至無提升)。
Dynamic Hierarchical Mimicking
作者重新對上述優化目標進行了分析并給出猜測:“最本質的原因在于損失函數中相加的兩塊損失優化目標不一致”。以分類為例,盡管兩者均意在優化交叉熵損失,但兩者在中間層的優化方向是不一致的,存在矛盾點,進而導致對最終模型性能產生負面影響。
針對上述問題,作者提出一種新穎的知識匹配損失用于正則化訓練過程,并使得不同損失對中間層的優化目標相一致,從而確保了模型的魯棒性與泛化性能。
所提方法的優化目標函數可以描述如下公式,其示意圖見上圖。
argminWm,WsL(Wm;D)+Ls(WΦ~;IΦ,D)+Lk(WΦ~;IΦ,D)argmin_{W_m, \mathcal{W}_s} \mathcal{L}(W_m;\mathcal{D}) + \mathcal{L}_s(\mathcal{W}_{\tilde{\Phi}};I_{\Phi},\mathcal{D}) + \mathcal{L}_k(\mathcal{W}_{\tilde{\Phi}};I_{\Phi}, \mathcal{D}) argminWm?,Ws??L(Wm?;D)+Ls?(WΦ~?;IΦ?,D)+Lk?(WΦ~?;IΦ?,D)
其中比較關鍵在于第三項的引入,也就是所提到的知識匹配損失。注:由于全文公式太多,本人只是相對粗略的看來一遍,沒有過于深度去研究。應該不會影響對其的認知,見后續的對比分析。
Experiments
為驗證所提方法的有效性,作者在多個數據集(Cifar,ImageNet,Market1501等)上的機型了實驗對比分析。
首先,給出了CIFAR-100數據集上所提方法與DSL的性能對比,見下圖。盡管DSL可以提升模型的性能,但提提升比較少,而作者所提DHM可以得到更高的性能提升。該實驗證實了所提方法的有效性。
然后,作者給出了ImageNet數據集上的性能對比,見下圖。可以得到與前面類似的結論,但同時可以看到:對于極深網絡(如ResNe152),DSL的性能提升非常有限,而所提方法仍能極大的提升模型的性能超1%。
其次,作者給出了Market1501數據集上的性能對比,見下圖。結論同前,不再贅述。
最后,作者還提供了其實驗過程中的網絡架構,這里僅提供一個參考模型(MobileNet)作為示例以及分析說明。除了MobileNet外,作者還提供了DenseNet、ResNet、WRN等實驗模型。
Discusion
實事求是的說,本人在看到最后的網絡結構和代碼之前是沒看明白這篇論文該怎么應用的。只是大概了解DSL破壞了深度網絡的分層特征表達能力,針對該問題而提出的解決方案。
看了論文和代碼后,基本上明白了作者是怎么做的。就一點:既然DSL破壞了深度網絡的分層特征表達能力,那么就想辦法去補償以不同損失反向傳播到中間層與底層時優化方向是一致的。那么該怎么去補償呢?下圖給出了圖示,中間主干分支表示預定義好的網絡結構,左右兩個分支表示作者補償的結構,通過這樣的方式可以確保主損失與右分支損失傳播到layer3的優化方向一致,主損失與做分支損失傳播到layer2的優化方向一致。當然圖中兩個顏色layer3表示這是不同的處理過程,分支的處理過程肯定要比主分支的計算量小,否則豈不是加大了訓練難度?
我想,看到這里大家基本上都明白了DHM這篇論文所要表達的思想了。接下來,將嘗試將其與其他類似的方法進行一下對比分析。首先給出傳統訓練方式、DSL訓練方式與DHM的對比圖(注:圖中暗紅色區域表示損失計算,具體怎么計算不詳述)。
上圖給出了常規訓練過程、DSL訓練過程以及DHM的訓練成果對比。常規訓練過程僅在head部分有一個損失;而DSN(即DSL)則有多個損失,不同的損失回傳的速度時不一樣的,比如左分支損失直接傳給了layer2,這明顯快于中間的主損失,這是緩解“梯度消失”的原因所在;DHM類似于DSL具有多個損失,但同時為防止不同損失對中間層優化方向的不一致,而添加了額外的輔助層,用于模擬深度網絡的分層特征表達。
那么DHM是如何緩解“梯度消失”現象的呢?個人認為,它有兩種方式:(1) ResNet與DenseNet中的緩解“梯度消失”的方式,這與網路結構有關;(2)分支層數少于主干層數,一定程度上緩解了“梯度消失”。
最后,再補上一個與DHM極為相似的方法DML,兩者的流程圖如下所示。論文原文確實提到了DML方法,但并未與之進行對比。從圖示可以看到兩者還是比較相似的,盡管DML初衷是兩個網絡采用知識蒸餾的方式進行訓練,而DHM則是針對DSL存在的缺陷進行的改進。
私認為DHM是DML的特例(注:僅僅從上述圖示出發),有這么三點原因:
- 損失函數方面,以圖像分類為例,DML與DHM均采用交叉熵損失+KL散度計算不同分支損失;
- 分支數方面:盡管DML原文是借鑒識蒸餾方式,但其分支可以不止兩個,比如擴展到三個呢,四個呢?這兩種方式是不是就一樣了呢?
- 網路結構方面:盡管DML提到的是兩個網絡,但是兩個網絡如果共享stem+layer1+layer2部分呢?從這個角度來看,DHM與DML殊途同歸了。
做完上述記錄后,本人厚著臉皮去騷擾了一下李鐸大神,請教了一下。經允許,現將作者的理解摘錄如下:
DSL存在的問題:(1) 特征逐級提取問題,如果像上述圖中googlenet/dsn那樣把head直接接在中間層立刻再接classifier,那么強制要求layer2、layer3、layer4都提取high-level語意特征,這和一般網絡里layer2、layer3可能還在提取更low-level的特征相違背;(2) 不同分支的gradient都會回傳到shared的主支上,如果這些gradient相互沖突甚至抵消,對于整個網絡的優化是產生負面影響的。
DHM的解決方案:(1)第一個問題通過圖中的分支網絡結構的改進來解決;(2)第二個問題則是通過KL散度損失隱式約束梯度來解決。
OK,關于DHM的介紹,全文到底結束!碼字不易,思考更不易,還請給個贊。
Reference
關注極市平臺公眾號(ID:extrememart),獲取計算機視覺前沿資訊/技術干貨/招聘面經等
總結
以上是生活随笔為你收集整理的不仅搞定“梯度消失”,还让CNN更具泛化性:港科大开源深度神经网络训练新方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 大盘点|卷积神经网络必读的 100 篇经
- 下一篇: 重磅开源!目标检测新网络 Detecto