ICCV 2019 开源论文 | 基于元学习和AutoML的模型压缩新方法
這篇文章來自于曠視。曠視內(nèi)部有一個(gè)基礎(chǔ)模型組,孫劍老師也是很看好 NAS 相關(guān)的技術(shù),相信這篇文章無論從學(xué)術(shù)上還是工程落地上都有可以讓人借鑒的地方。回到文章本身,模型剪枝算法能夠減少模型計(jì)算量,實(shí)現(xiàn)模型壓縮和加速的目的,但是模型剪枝過程中確定剪枝比例等參數(shù)的過程實(shí)在讓人頭痛。
這篇文章提出了 PruningNet 的概念,自動(dòng)為剪枝后的模型生成權(quán)重,從而繞過了費(fèi)時(shí)的 retrain 步驟。并且能夠和進(jìn)化算法等搜索方法結(jié)合,通過搜索編碼 network 的 coding vector,自動(dòng)地根據(jù)所給約束搜索剪枝后的網(wǎng)絡(luò)結(jié)構(gòu)。和 AutoML 技術(shù)相比,這種方法并不是從頭搜索,而是從已有的大模型出發(fā),從而縮小了搜索空間,節(jié)省了搜索算力和時(shí)間。
個(gè)人覺得這種剪枝和 NAS 結(jié)合的方法,應(yīng)該會(huì)在以后吸引越來越多人的注意。這篇文章的代碼已經(jīng)開源在了 Github:
https://github.com/liuzechun/MetaPruningMotivation
模型剪枝是一種能夠減少模型大小和計(jì)算量的方法。模型剪枝一般可以分為三個(gè)步驟:
訓(xùn)練一個(gè)參數(shù)量較多的大網(wǎng)絡(luò)
將不重要的權(quán)重參數(shù)剪掉
剪枝后的小網(wǎng)絡(luò)做 fine tune
其中第二步是模型剪枝中的關(guān)鍵。有很多 paper 圍繞“怎么判斷權(quán)重是否重要”以及“如何剪枝”等問題進(jìn)行討論。困擾模型剪枝落地的一個(gè)問題就是剪枝比例的確定。
傳統(tǒng)的剪枝方法常常需要人工 layer by layer 地去確定每層的剪枝比例,然后進(jìn)行 fine tune,用起來很耗時(shí),而且很不方便。不過最近的?Rethinking the Value of Network Pruning?[1] 指出,剪枝后的權(quán)重并不重要,對(duì)于 channel pruning 來說,更重要的是找到剪枝后的網(wǎng)絡(luò)結(jié)構(gòu),具體來說就是每層留下的 channel 數(shù)量。
受這個(gè)發(fā)現(xiàn)啟發(fā),文章提出可以用一個(gè) PruningNet,對(duì)于給定的剪枝網(wǎng)絡(luò),自動(dòng)生成 weight,無需進(jìn)行 retrain,然后評(píng)測剪枝網(wǎng)絡(luò)在驗(yàn)證集上的性能,從而選出最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)。
具體來說,PruningNet 的輸入是剪枝后的網(wǎng)絡(luò)結(jié)構(gòu),必須首先對(duì)網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行編碼,轉(zhuǎn)換為 coding vector。這里可以直接用剪枝后網(wǎng)絡(luò)每層的 channel 數(shù)來編碼。在搜索剪枝網(wǎng)絡(luò)的時(shí)候,我們可以嘗試各種 coding vector,用 PruningNet 生成剪枝后的網(wǎng)絡(luò)權(quán)重。網(wǎng)絡(luò)結(jié)構(gòu)和權(quán)重都有了,就可以去評(píng)測網(wǎng)絡(luò)的性能。進(jìn)而用進(jìn)化算法搜索最優(yōu)的 coding vector,也就是最優(yōu)的剪枝結(jié)構(gòu)。在用進(jìn)化算法搜索的時(shí)候,可以使用自定義的目標(biāo)函數(shù),包括將網(wǎng)絡(luò)的 accuracy,latency,FLOPS 等考慮進(jìn)來。
PruningNet
從上一小節(jié)已經(jīng)可以知道,PruningNet 是整個(gè)算法的關(guān)鍵。那么怎么才能找到這樣一個(gè)“神奇網(wǎng)絡(luò)”呢?
先做一下符號(hào)約定,使用 ci?表示剪枝之后第 i 層的 channel 數(shù)量, l 為網(wǎng)絡(luò)的層數(shù), W 表示剪枝后網(wǎng)絡(luò)的權(quán)重。那么 PruningNet 的輸入輸出如下所示:
訓(xùn)練
先結(jié)合下圖看一下 forward 部分。PruningNet 是由 l 個(gè) PruningBlock 組成的,每個(gè) PruningBlock 是一個(gè)兩層的 MLP。
首先看圖 b,編碼著網(wǎng)絡(luò)結(jié)構(gòu)信息的 coding vector 輸入到當(dāng)前 block 后,輸出經(jīng)過 Reshape,成了一個(gè) Weight Matrix。注意哦,這里的 WeightMatrix 是固定大小的(也就是未剪枝的原始 Weight shape 大小),和剪枝網(wǎng)絡(luò)結(jié)構(gòu)無關(guān)。
再看圖 a,因?yàn)橐獙?duì)網(wǎng)絡(luò)進(jìn)行剪枝,所以 WeightMatrix 要進(jìn)行 Crop。對(duì)應(yīng)到圖 b,可以看到,Crop 是在兩個(gè)維度上進(jìn)行的。首先,由于上一層也進(jìn)行了剪枝,所以 input channel 數(shù)變少了;其次,由于當(dāng)前層進(jìn)行了剪枝,所以 output channel 數(shù)變少了。這樣經(jīng)過 Crop,就生成了剪枝后的網(wǎng)絡(luò) weight。我們再輸入一個(gè) mini batch 的訓(xùn)練圖片,就可以得到剪枝后的網(wǎng)絡(luò)的 loss。
在 backward 部分,我們不更新剪枝后網(wǎng)絡(luò)的權(quán)重,而是更新 PruningNet 的權(quán)重。由于上面的操作都是可微分的,所以直接用鏈?zhǔn)椒▌t傳過去就行。如果你使用 PyTorch 等支持自動(dòng)微分的框架,這是很容易的。
下圖所示是訓(xùn)練過程的整個(gè) PruningNet(左側(cè))和剪枝后網(wǎng)絡(luò)(右側(cè),即 PrunedNet)。訓(xùn)練過程中的 coding vector 在狀態(tài)空間里隨機(jī)采樣,隨機(jī)選取每層的 channel 數(shù)量。
PS:和原始論文相比,下圖和上圖順序是顛倒的。這里從底向上介紹了 PruningNet 的訓(xùn)練,而論文則是自頂向下。
搜索
訓(xùn)練好 PruningNet 后,就可以用它來進(jìn)行搜索了!我們只需要輸入某個(gè) coding vector,PruningNet 就會(huì)為我們生成對(duì)應(yīng)每層的 WeightMatrix。別忘了 coding vector 是編碼的網(wǎng)絡(luò)結(jié)構(gòu),現(xiàn)在又有了 weight,我們就可以在驗(yàn)證集上測試網(wǎng)絡(luò)的性能了。進(jìn)而,可以使用進(jìn)化算法等優(yōu)化方法去搜索最優(yōu)的 coding vector。當(dāng)我們得到了最優(yōu)結(jié)構(gòu)的剪枝網(wǎng)絡(luò)后,再 from scratch 地訓(xùn)練它。
進(jìn)化算法這里不再贅述,很多優(yōu)化的書中包括網(wǎng)上都有資料。這里把整個(gè)算法流程貼出來:
實(shí)驗(yàn)
作者在 ImageNet 上用 MobileNet 和 ResNet 進(jìn)行了實(shí)驗(yàn)。訓(xùn)練 PruningNet 用了 1/4 的原模型的 epochs。數(shù)據(jù)增強(qiáng)使用常見的標(biāo)準(zhǔn)流程,輸入 image 大小為 224×224。
將原始 ImageNet 的訓(xùn)練集做分割,每個(gè)類別選 50 張組成 sub-validation(共計(jì) 50000),其余作為 sub-training。在訓(xùn)練時(shí),我們使用 sub-training 訓(xùn)練 PruningNet。在搜索時(shí),使用 sub-validation 評(píng)估剪枝網(wǎng)絡(luò)的性能。不過,還要注意,在搜索時(shí),使用 20000 張 sub-training 中的圖片重新計(jì)算 BatchNorm layer 中的 running mean 和 running variance。
shortcut 剪枝
在進(jìn)行模型剪枝時(shí),一個(gè)比較難處理的問題是 ResNet 中的 shortcut 結(jié)構(gòu)。因?yàn)樽詈笥幸粋€(gè) element-wise 的相加操作,必須保證兩路 feature map 是嚴(yán)格 shape 相同的,所以不能隨意剪枝,否則會(huì)造成 channel 不匹配。下面對(duì)幾種論文中用到的網(wǎng)絡(luò)結(jié)構(gòu)分別討論。
MobileNet-v1
MobileNet-v1 是沒有 shortcut 結(jié)構(gòu)的。我們?yōu)槊總€(gè) conv layer 都配上相應(yīng)的 PruningBlock——一個(gè)兩層的 MLP。PruningNet 的輸入 coding vector 中的元素是剪枝后每層的 channel 數(shù)量。而輸入第 i 個(gè) PruningBlock 的是一個(gè) 2D vector,由歸一化的第 i-1 層和第 i 層的剪枝比例構(gòu)成。這部分可以結(jié)合代碼來看:
https://github.com/liuzechun/MetaPruning/blob/master/mobilenetv1/training/mobilenet_v1.py#L15
注意第 1 個(gè) conv layer 的輸入是 1D vector,因?yàn)樗堑谝粋€(gè)被剪枝的 layer。在訓(xùn)練時(shí),coding vector 的搜索空間被以一定步長劃分為 grid,采樣就是在這些格點(diǎn)上進(jìn)行的。
MobileNet-v2
MobileNet-v2 引入了類似 ResNet 的 shortcut 結(jié)構(gòu),這種 resnet block 必須統(tǒng)一看待。具體來說,對(duì)于沒有在 resnet block 中的conv,處理方法如 MobileNet-v1。對(duì)每個(gè) resnet block,配上一個(gè)相應(yīng)的 PruningBlock。由于每個(gè) resnet block 中只有一個(gè)中間層(3×3 的 conv),所以輸出第 i 個(gè) PruningBlock 的是一個(gè) 3D vector,由歸一化的第 i-1 個(gè) resnet block,第 i 個(gè) resnet block 和中間 conv 層的剪枝比例構(gòu)成。其他設(shè)置和 MobileNet-v1 相同。這里可以結(jié)合代碼來看:
https://github.com/liuzechun/MetaPruning/blob/master/mobilenetv2/training/mobilenet_v2.py#L109
ResNet
處理方法如 MobileNet-v2 所示。可以結(jié)合代碼來看:
https://github.com/liuzechun/MetaPruning/blob/master/resnet/training/resnet.py#L75
實(shí)驗(yàn)結(jié)果
在相近 FLOPS 情況下,和 MobileNet 論文中改變 ratio 參數(shù)得到的模型比較,MetaPruning 得到的模型 accuracy 更高。尤其是壓縮比例更大時(shí),該方法更有優(yōu)勢。
和其他剪枝方法(如?AMC [2])等方法比較,該方法也得到了 SOTA 的結(jié)果。MetaPruning 方法能夠以一種統(tǒng)一的方法處理 ResNet 中的 shortcut 結(jié)構(gòu),并且不需要人工調(diào)整太多的參數(shù)。
上面的比較都是基于理論 FLOPS,現(xiàn)在更多人在關(guān)注網(wǎng)絡(luò)在實(shí)際硬件上的 latency 怎么樣。文章對(duì)此也進(jìn)行了討論。如何測試網(wǎng)絡(luò)的 latency?
當(dāng)然可以每個(gè)網(wǎng)絡(luò)都實(shí)際跑一下,不過有些麻煩。基于每個(gè) layer 的 inference 時(shí)間是互相獨(dú)立的這個(gè)假設(shè),作者首先構(gòu)造了各個(gè) layer inference latency 的查找表(參見論文?Fbnet: Hardware-aware efficient convnet design via differentiable neural architecture search [3]),以此來估計(jì)實(shí)際網(wǎng)絡(luò)的 latency。作者這里和 MobileNet baseline 做了比較,結(jié)果也證明了該方法更優(yōu)。
PruningNet 結(jié)果分析
此外,作者還對(duì) PruningNet 的預(yù)測結(jié)果進(jìn)行可視化,試圖找出一些可解釋性,并找出剪枝參數(shù)的一些規(guī)律。
down-sampling 的部分 PruningNet 傾向于保留更多的 channel,如 MobileNet-v2 block 中間的那個(gè) conv;
優(yōu)先剪淺層 layer 的 channel,FLOPS 約束太強(qiáng)剪深層的 channel,但可能會(huì)造成網(wǎng)絡(luò) accuracy 下降比較多。
總結(jié)
這篇文章從“剪枝后的 weight 作用不大”的現(xiàn)象出發(fā),將剪枝和 NAS 結(jié)合,提出了 PruningNet 為剪枝后的網(wǎng)絡(luò)預(yù)測 weight,避免了網(wǎng)絡(luò)的 retrain,從而可以快速衡量剪枝網(wǎng)絡(luò)的性能。并在編碼網(wǎng)絡(luò)信息的 coding vector 狀態(tài)空間進(jìn)行搜索,找到給定約束條件下的最優(yōu)網(wǎng)絡(luò)結(jié)構(gòu),在 ImageNet 數(shù)據(jù)集和 ResNet/MobileNet-v1/v2 上取得了比之前剪枝算法更好的效果。總結(jié)
隨著深度神經(jīng)網(wǎng)絡(luò)模型在各個(gè)場景下的落地,模型的壓縮和加速越來越受到大家的重視,剪枝是其中的重要方法。傳統(tǒng)的剪枝算法人工確定較多的參數(shù),所以很多文章開始考慮端到端的剪枝。
這篇論文把剪枝算法和 NAS 結(jié)合,取兩者之長,用待剪枝的模型縮小了搜索空間,用進(jìn)化算法自動(dòng)搜索最優(yōu)網(wǎng)絡(luò)結(jié)構(gòu)。使用 coding vector 編碼網(wǎng)絡(luò)結(jié)構(gòu),用一個(gè)很簡單的雙隱層感知機(jī)預(yù)測網(wǎng)絡(luò)權(quán)重,并提出了一種 shortcut 的處理方法,在 ImageNet 數(shù)據(jù)集和幾種常用網(wǎng)絡(luò)結(jié)構(gòu)上取得了不錯(cuò)的結(jié)果。文章提出的方法簡單易于操作,可以很方便地應(yīng)用到自己的業(yè)務(wù)場景中。相關(guān)代碼已經(jīng)開源在 Github 上。
相關(guān)鏈接
[1]?https://arxiv.org/abs/1810.05270
[2]?https://arxiv.org/abs/1802.03494
[3]?https://arxiv.org/abs/1812.03443
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
#投 稿 通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)習(xí)心得或技術(shù)干貨。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來。
??來稿標(biāo)準(zhǔn):
? 稿件確系個(gè)人原創(chuàng)作品,來稿需注明作者個(gè)人信息(姓名+學(xué)校/工作單位+學(xué)歷/職位+研究方向)?
? 如果文章并非首發(fā),請(qǐng)?jiān)谕陡鍟r(shí)提醒并附上所有已發(fā)布鏈接?
? PaperWeekly 默認(rèn)每篇文章都是首發(fā),均會(huì)添加“原創(chuàng)”標(biāo)志
? 投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請(qǐng)單獨(dú)在附件中發(fā)送?
? 請(qǐng)留下即時(shí)聯(lián)系方式(微信或手機(jī)),以便我們在編輯發(fā)布時(shí)和作者溝通
?
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個(gè)推薦、解讀、討論、報(bào)道人工智能前沿論文成果的學(xué)術(shù)平臺(tái)。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號(hào)后臺(tái)點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
▽ 點(diǎn)擊 |?閱讀原文?| 下載論文 & 源碼
總結(jié)
以上是生活随笔為你收集整理的ICCV 2019 开源论文 | 基于元学习和AutoML的模型压缩新方法的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 旱烟和香烟哪个危害大?
- 下一篇: 菅怎么读什么意思草菅人命