ICML 2021 | AlphaNet:基于α-散度的超网络训练方法
?作者 | 韓翔宇
學(xué)校 |?南昌大學(xué)
研究方向 | 神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索
概述
本文是 ICML?2021 收錄的 NAS(神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索)領(lǐng)域的重磅論文,其作者是 AttentiveNas 的原作者,在 AttentiveNas 的基礎(chǔ)上,添加了 α- 散度損失函數(shù),在 ImageNet(NAS 方向)排行榜中取得了 SOTA 結(jié)果。
論文標(biāo)題:
AlphaNet: Improved Training of Supernets with Alpha-Divergence
論文鏈接:
https://arxiv.org/abs/2102.07954
代碼鏈接:
https://github.com/facebookresearch/AlphaNet
解決的問(wèn)題
盡管權(quán)值共享的 NAS 有較好的效果,因?yàn)檫@種方法構(gòu)建了一個(gè)超網(wǎng),利用子網(wǎng)來(lái)訓(xùn)練超網(wǎng)。然而,權(quán)重共享 NAS 的成功很大程度上依賴于將超網(wǎng)絡(luò)的知識(shí)提取到子網(wǎng)絡(luò)。如果廣泛使用蒸餾中的散度,例如 KL 散度,可能導(dǎo)致學(xué)生網(wǎng)絡(luò)高估或者低估教師網(wǎng)絡(luò)的不確定性,導(dǎo)致了子網(wǎng)的效果變差。論文提出了更廣義的 α- 散度來(lái)改進(jìn)超網(wǎng)訓(xùn)練。通過(guò)自適應(yīng)選擇 α- 散度,避免了高估或者低估教師模型不確定性。改進(jìn)之后的 AlphaNet 在 ImageNet top-1 的精度達(dá)到了 80%,且參數(shù)量只有 444M。
簡(jiǎn)介
傳統(tǒng)的 NAS 方法代價(jià)非常大,需要數(shù)百個(gè)網(wǎng)絡(luò)結(jié)構(gòu)從頭訓(xùn)練、驗(yàn)證效果。超網(wǎng)將所有候選體系結(jié)構(gòu)組裝成一個(gè)權(quán)重共享網(wǎng)絡(luò),每個(gè)網(wǎng)絡(luò)結(jié)構(gòu)對(duì)應(yīng)一個(gè)子網(wǎng)絡(luò)。通過(guò)同時(shí)訓(xùn)練子網(wǎng)和超網(wǎng),子網(wǎng)可以直接從超網(wǎng)中獲得的權(quán)重,用來(lái)重新訓(xùn)練和驗(yàn)證,而不需要單獨(dú)訓(xùn)練或微調(diào)每個(gè)結(jié)構(gòu),因此成本大大降低。為了穩(wěn)定超網(wǎng)訓(xùn)練和提高子網(wǎng)絡(luò)的性能,大家廣泛使用知識(shí)蒸餾的方式。知識(shí)蒸餾用超網(wǎng)中最大的子網(wǎng)預(yù)測(cè)的軟標(biāo)簽來(lái)監(jiān)督所有其他子網(wǎng),提取教師模型的知識(shí)來(lái)提高子網(wǎng)性能。
一般情況下,知識(shí)蒸餾使用 KL 散度衡量師生網(wǎng)絡(luò)之間的差異。但是如果學(xué)生網(wǎng)絡(luò)對(duì)教師網(wǎng)絡(luò)的某些部分覆蓋不完整,對(duì)學(xué)生網(wǎng)絡(luò)的懲罰大。因此,學(xué)生模型往往高估了教師模型的不確定性,不能準(zhǔn)確近似教師模型的正確預(yù)測(cè)。
為了解決這個(gè)問(wèn)題,論文提出用更廣義的 α- 散度代替 KL 散度,具體來(lái)講,通過(guò)自適應(yīng)控制散度度量中的 α,可以同時(shí)懲罰對(duì)教師網(wǎng)絡(luò)高估或者低估教師網(wǎng)絡(luò)不確定性的行為,鼓勵(lì)學(xué)生網(wǎng)絡(luò)更好地近似教師網(wǎng)絡(luò)。直接優(yōu)化 α 散度可能會(huì)受到梯度的高方差的影響,論文游提出了一種簡(jiǎn)單的梯度裁剪技術(shù),穩(wěn)定訓(xùn)練過(guò)程。通過(guò)提出的自適應(yīng) α- 散度,我們能夠訓(xùn)練高質(zhì)量的 alphanet,在 200 到 800 MFLOPs 范圍表現(xiàn)都達(dá)到 SOTA 效果。
關(guān)于知識(shí)蒸餾的相關(guān)知識(shí)
在權(quán)值共享的 NAS 中,知識(shí)蒸餾是重要的方法。假設(shè)超網(wǎng)有可訓(xùn)練的參數(shù) θ,訓(xùn)練超網(wǎng)的目的是學(xué)習(xí) θ,讓所有的子網(wǎng)都能同時(shí)得到優(yōu)化,達(dá)到較高的準(zhǔn)確率。
上圖描述了采用知識(shí)蒸餾的超網(wǎng)訓(xùn)練過(guò)程。在每個(gè)訓(xùn)練步驟中,給定一小批數(shù)據(jù),對(duì)超網(wǎng)和幾個(gè)子網(wǎng)絡(luò)采樣。當(dāng)超網(wǎng)使用真實(shí)標(biāo)簽訓(xùn)練時(shí),所有抽樣的子網(wǎng)絡(luò)都使用超網(wǎng)預(yù)測(cè)的軟標(biāo)簽進(jìn)行監(jiān)督訓(xùn)練。然后對(duì)所有采樣的網(wǎng)絡(luò)的梯度進(jìn)行融合,更新超網(wǎng)參數(shù)。在第 t 個(gè) step 時(shí),超參數(shù) θ 被更新為如下(其中 ε 是 step 的數(shù)量):
g 定義如下:其中 LD 是數(shù)據(jù)集 D 上超網(wǎng)的損失函數(shù)(交叉熵), 為權(quán)重系數(shù), 是超網(wǎng)和采樣出子網(wǎng)的 KL 散度, 和 表示:輸入 x 的超網(wǎng)和子網(wǎng) s 的輸出概率。
知識(shí)蒸餾中KL散度的局限性(α-散度解決的問(wèn)題)
KL 散度廣泛用于衡量教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)之間的差異性。但是 KL 散度的缺點(diǎn)在于,當(dāng)學(xué)生網(wǎng)絡(luò)高估了教師網(wǎng)絡(luò)的不確定性時(shí),不能充分懲罰學(xué)生網(wǎng)絡(luò)。我們首先列舉一下 KL 散度的公式:
▲ KL散度公式(公式2為f散度,即KL散度的一般形式)
我們用 P 代表教師網(wǎng)絡(luò)的某個(gè)輸出概率,用 Q 代表學(xué)生網(wǎng)絡(luò)的某個(gè)輸出概率。KL 散度有個(gè)避零性質(zhì),當(dāng) P=0 時(shí),我們看等式右半部分 log(P/Q),因?yàn)榉肿訛?0,我們知道 log 函數(shù)趨向于 0 的時(shí)候,值趨向于負(fù)無(wú)窮。無(wú)論分母 Q 怎么變大(也就是是學(xué)生網(wǎng)絡(luò)某個(gè)輸出概率變大),值始終是負(fù)無(wú)窮,對(duì)應(yīng) KL 散度低,也就是對(duì)學(xué)生網(wǎng)絡(luò)的 Q 懲罰小,所以學(xué)生網(wǎng)絡(luò)就可以肆無(wú)忌憚地預(yù)測(cè),即使預(yù)測(cè)錯(cuò)誤很離譜,也不會(huì)懲罰。
被如下圖所示,針對(duì)某個(gè)輸出,橙色的教師網(wǎng)絡(luò)主要預(yù)測(cè)的類別是 3(概率最大的是 3,所以輸出為 3),其次是類別 4,預(yù)測(cè)其他類別的概率小的可以忽略;再看綠色的學(xué)生網(wǎng)絡(luò),預(yù)測(cè)的最主要的類別是 4(概率最大的是 4,所以輸出為 4,輸出的類別就已經(jīng)和教師網(wǎng)絡(luò)有偏差了),而且學(xué)生網(wǎng)絡(luò)預(yù)測(cè)其他類別的概率也比教師網(wǎng)絡(luò)高,所以學(xué)生網(wǎng)絡(luò)的不確定性比教師網(wǎng)絡(luò)大,換句話說(shuō),學(xué)生網(wǎng)絡(luò)高估了教師網(wǎng)絡(luò)的不確定性。
反過(guò)來(lái)想,我們用 Q 代表教師網(wǎng)絡(luò)的某個(gè)輸出概率,用 P 代表學(xué)生網(wǎng)絡(luò)的某個(gè)輸出概率(和上個(gè)假設(shè)掉換個(gè)位置)。這次讓 Q=0,這時(shí)就要再講 KL 散度的另一個(gè)性質(zhì):零強(qiáng)制性。用這個(gè)例子簡(jiǎn)單來(lái)簡(jiǎn)單解釋,就是分母 Q 為 0,P 也要為 0。這很好理解,當(dāng)分母為 0 的時(shí)候,分子如果不為 0 就會(huì)報(bào)除 0 錯(cuò)誤,但是當(dāng)分子分母都為 0 時(shí),P/Q=1, log(P/Q)=log1=0。
說(shuō)完這個(gè)性質(zhì)以后,我們想一想,如果 Q 為 0,KL 散度的值為 0,我們希望損失越小越好,我們希望損失比 0 更小,也就是負(fù)數(shù)。因此最小化 KL 散度會(huì)鼓勵(lì) P 避免趨向 0(因?yàn)楫?dāng) P 為 0 的時(shí)候,Q 可能為 0,Q一旦為 0,損失函數(shù)就一定為 0,也就不能繼續(xù)變小了,我們希望損失函數(shù)越小越好,當(dāng)然 P 不愿意趨向 0)。
從論文的角度來(lái)說(shuō),也就是學(xué)生模型會(huì)避免低概率模式,也就是學(xué)生模型會(huì)傾向于較高的概率。如下圖所示,學(xué)生模型預(yù)測(cè)類別 2 的概率為 100%,幾乎沒(méi)有不確定性可言,而教師網(wǎng)絡(luò)除了預(yù)測(cè)類別最多的 2,其他類別也有概率分布,所以教師網(wǎng)絡(luò)的不確定性比學(xué)生網(wǎng)絡(luò)大,也就是學(xué)生網(wǎng)絡(luò)低估了教師網(wǎng)絡(luò)的不確定性。
使用α-散度訓(xùn)練的超網(wǎng)
為了解決上述 KL 散度的不確定性,論文提出了一個(gè)靈活的 α- 散度(α∈R 且 α≠0 且 α≠1):
和 代表每個(gè)類的離散分布,共有 個(gè)類別。當(dāng) 時(shí), 的極限就是 ,同樣, 是 時(shí) 的極限。 散度的關(guān)鍵點(diǎn)在于,可以通過(guò)選擇不同的 值來(lái)集中懲罰不同類型的差異(低估或高估)。
如圖所示,當(dāng)阿爾法為負(fù)值的時(shí)候,藍(lán)色線代表了學(xué)生網(wǎng)絡(luò) 高估了教師網(wǎng)絡(luò) 的不確定性的情況,這時(shí) 很大。紫色線表示學(xué)生網(wǎng)絡(luò) 低估了教師網(wǎng)絡(luò) 的不確定性的情況,此時(shí) 很小。當(dāng) 為正數(shù)的時(shí)候,情況正好相反。
為了同時(shí)緩解超網(wǎng)訓(xùn)練時(shí)的高估和低估的問(wèn)題,用一個(gè)正的 和一個(gè)負(fù)的 ,在知識(shí)蒸餾損失函數(shù)中使用 和 中最大那個(gè),也就是:
總的 KL 散度為:
進(jìn)一步改進(jìn)的穩(wěn)定的α-散度
人們傾向于將 和 設(shè)置為較大,以確保學(xué)生模型在低估或高估教師模型的不確定性時(shí)受到足夠的懲罰。但是直接通過(guò)增加 絕對(duì)值這種方法,會(huì)讓優(yōu)化變得困難,我們首先來(lái)看一下 散度的梯度:
不難看出,如果 很大,則 也可能變得很大,從而影響了訓(xùn)練的穩(wěn)定性。為了讓訓(xùn)練變得穩(wěn)定,我們把 的最大值限定為 (如果小于 則不變,最大不能超過(guò) ),重新定義梯度表達(dá)式為:
該梯度等價(jià)于 p 和 q 的 f 散度(f 散度是 KL 散度的一般形式,如果你認(rèn)真看了這篇文章,應(yīng)該會(huì)注意到 KL 散度的公式有兩個(gè),沒(méi)錯(cuò),第二個(gè)公式就是 f 散度),進(jìn)行梯度更新相當(dāng)于最小化有效散度。通過(guò)裁剪重要性權(quán)重的值,我們優(yōu)化的仍然是一個(gè)散度度量,但對(duì)基于梯度的優(yōu)化更友好。
具體實(shí)現(xiàn)
自適應(yīng)的 α- 散度的設(shè)置:α- 和 α+ 分別控制對(duì)過(guò)高估計(jì)和過(guò)低估計(jì)的懲罰幅度。并且,β 控制了教師模型和學(xué)生模型之間的密度比率范圍(也就是子網(wǎng)占了超網(wǎng)多大一部分)。通過(guò)實(shí)驗(yàn)發(fā)現(xiàn) AlphaNet 的方法在 α?,α+ 和 β 的選擇范圍內(nèi)表現(xiàn)相對(duì)穩(wěn)健。實(shí)驗(yàn)中選擇 α?=?1,α+=1 和 β=5.0 作為默認(rèn)值。作者在 Silimmable Network 和權(quán)值共享的 NAS 兩個(gè)層面分別做了實(shí)驗(yàn):
8.1 Slimmable Network部分
Slimmable 網(wǎng)絡(luò)是支持選擇多種通道寬度的超網(wǎng),其搜索空間包含不同寬度的網(wǎng)絡(luò)和所有其他參數(shù)(深度、卷積類型、kernal 大小)都是相同的。Slimmable 允許不同的設(shè)備或應(yīng)用程序根據(jù)設(shè)備上的資源限制,自適應(yīng)地調(diào)整模型寬度,以實(shí)現(xiàn)最佳精度與能效的權(quán)衡。
使用 Slimmable MobileNet v1 和 v2 測(cè)試,其中 v1 的寬度范圍是 [0.25,1],v2 的寬度是 [0.35,1]。在每次訓(xùn)練迭代中,分別對(duì) channel 寬度最大的最大子網(wǎng)絡(luò)、channel 寬度最小的最小子網(wǎng)絡(luò)和兩個(gè)隨機(jī)子網(wǎng)絡(luò)進(jìn)行采樣,累積梯度(這種方式被稱為三明治法則)。使用 gt 標(biāo)簽訓(xùn)練超網(wǎng),使用知識(shí)蒸餾訓(xùn)練采樣到的子網(wǎng),設(shè)置知識(shí)蒸餾中的系數(shù) γ=3,作為抽樣獲得的子網(wǎng)數(shù)量。
為了驗(yàn)證自適應(yīng) α- 散度的有效性,使用它替換 baseline 中的 KL散度。使用 SGD 優(yōu)化器訓(xùn)練 360epoch,動(dòng)量為 0.9,重量衰減為 10?5,dropout 為 0.2。我們使用余弦學(xué)習(xí)速率衰減,初始學(xué)習(xí)速率為 0.8,batch_size 為 2048,使用 16 塊 GPU。
在 ImageNet 上進(jìn)行訓(xùn)練,上表是訓(xùn)練得到的最佳精度,可以看到,無(wú)論在任何寬度上,自適應(yīng)的α. 散度效果均優(yōu)于 KL 散度(KL-KD)和不加 KL 散度(KD)。
8.2 權(quán)值共享的NAS部分
8.2.1 簡(jiǎn)介
大多數(shù)基于權(quán)值共享的 NAS 由以下兩個(gè)階段組成(基于強(qiáng)化學(xué)習(xí)的NAS也一樣):
階段 1:使用可微分權(quán)值共享或者看做一個(gè)黑盒的優(yōu)化
階段 2:從頭開(kāi)始訓(xùn)練深度神經(jīng)網(wǎng)絡(luò),以獲得最佳的準(zhǔn)確率和最終的效果
但是這類 NAS 有缺陷:如果需要不同的硬件約束條件,需要重新搜索。而且要求對(duì)所有的候選結(jié)果從頭訓(xùn)練,達(dá)到理想的準(zhǔn)確率。因此,顯著增加了 NAS 的搜索成本。
論文使用的權(quán)重共享的 NAS 是基于超網(wǎng)的權(quán)重共享 NAS,搜索過(guò)程如下:
階段1:聯(lián)合優(yōu)化搜索空間中的超網(wǎng)和所有可能被采樣的子網(wǎng),使所有可搜索網(wǎng)絡(luò)在訓(xùn)練結(jié)束時(shí)都能獲得較好的性能。
階段2:然后所有的子網(wǎng)絡(luò)同時(shí)被優(yōu)化。然后可以使用典型的搜索算法,比如進(jìn)化算法,來(lái)搜索感興趣的最佳模型。每個(gè)子網(wǎng)絡(luò)的模型權(quán)值直接從預(yù)訓(xùn)練的超網(wǎng)絡(luò)繼承而來(lái),無(wú)需再進(jìn)行再訓(xùn)練或微調(diào)。
與基于 RL 的 NAS 算法和可微 NAS 算法相比,基于超網(wǎng)絡(luò)的權(quán)重共享的優(yōu)勢(shì)主要有:
1. 只需要對(duì)超網(wǎng)進(jìn)行一次訓(xùn)練。搜索空間中定義的所有子網(wǎng)絡(luò)在第1階段完全優(yōu)化后即可使用。不需要再訓(xùn)練或微調(diào);
2. 在階段 1 中,所有不同模型規(guī)模的子網(wǎng)絡(luò)進(jìn)行聯(lián)合優(yōu)化,找到一組帕累托最優(yōu)模型,這一組模型天生支持各種需要考慮的資源。
注意,權(quán)重共享 NAS 的一個(gè)主要步驟是同時(shí)訓(xùn)練搜索空間中指定的所有子網(wǎng)絡(luò)收斂。與訓(xùn)練 Slimmable 神經(jīng)網(wǎng)絡(luò)類似,也是用基于 KL 散度的知識(shí)蒸餾方法,強(qiáng)制所有的子網(wǎng)絡(luò)使從超網(wǎng)中學(xué)習(xí)來(lái)實(shí)現(xiàn)的。
8.2.2 具體訓(xùn)練
為了簡(jiǎn)單起見(jiàn),我們使用均勻抽樣策略,使用和 Silimable 的三明治法則相同的方式,每次迭代訓(xùn)練四個(gè)網(wǎng)絡(luò)。使用 SGD 和余弦淬火策略衰減學(xué)習(xí)率,使用 AutoAugment 進(jìn)行數(shù)據(jù)增強(qiáng),設(shè)置數(shù)據(jù)的標(biāo)簽平滑為 0.1。
上表是論文采用的搜索空間,MBConv 是 mobileNet 的倒殘差模塊。使用 swish 激活函數(shù),channel width 表示塊的輸出維度。輸入的分辨率表示候選的分辨率,為了簡(jiǎn)化數(shù)據(jù)加載過(guò)程,預(yù)選選取固定大小 224*224,然后再使用雙三次插值將它們重新縮放到我們的目標(biāo)分辨率。
8.2.3 驗(yàn)證
我們比較通過(guò)知識(shí)蒸餾策略得到不同超網(wǎng),評(píng)價(jià)采用準(zhǔn)確率和 FLOPs 的帕累托最優(yōu),主要包括以下三點(diǎn):
1. 首先從超網(wǎng)中隨機(jī)抽取 512 個(gè)子網(wǎng),并估計(jì)它們?cè)?ImageNet 驗(yàn)證集上的準(zhǔn)確性。
2. 對(duì)性能最好的 128 個(gè)子網(wǎng)應(yīng)用交叉和隨機(jī)變異,將交叉規(guī)模和變異規(guī)模都固定為 128,共產(chǎn)生 256 個(gè)新的子網(wǎng)。然后我們?cè)u(píng)估了這些子網(wǎng)的性能。
3. 重復(fù)步驟 20 次步驟 2,得到 5376 個(gè)子網(wǎng)。
8.2.4 最終效果
上表可以看出, AlphaNet 相比于普通的 KL 散度和不使用 KL 散度的訓(xùn)練方法,在各個(gè)參數(shù)量階段的準(zhǔn)確率都得到了提升。
上圖分別是不用 KL 散度(棕色),使用 KL 散度(綠色),AlphaNet(紅色)的訓(xùn)練收斂曲線,可以看出,使用 AlphaNet 訓(xùn)練在前期收斂速度會(huì)慢于普通的 KL 散度,但是在后期收斂速度會(huì)超過(guò)普通的 KL 散度。
論文將 AlphaNet 和現(xiàn)有的 NAS 做對(duì)比,相比于 EfficientNet、MobileV3、MNasNet、BigNas 等模型都有較為明顯的提升,取得了 SOTA 的效果。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點(diǎn)剖析、科研心得或競(jìng)賽經(jīng)驗(yàn)講解等。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來(lái)。
📝?稿件基本要求:
? 文章確系個(gè)人原創(chuàng)作品,未曾在公開(kāi)渠道發(fā)表,如為其他平臺(tái)已發(fā)表或待發(fā)表的文章,請(qǐng)明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無(wú)版權(quán)問(wèn)題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競(jìng)爭(zhēng)力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來(lái)稿請(qǐng)備注即時(shí)聯(lián)系方式(微信),以便我們?cè)诟寮x用的第一時(shí)間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長(zhǎng)按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的ICML 2021 | AlphaNet:基于α-散度的超网络训练方法的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: ghost 系统光盘 怎么拷贝硬盘 如何
- 下一篇: 联想怎么在bios里设置u盘启动项 &a