一文通俗讲解元学习(Meta-Learning)
?PaperWeekly 原創(chuàng) ·?作者 | 孫裕道
學校 | 北京郵電大學博士生
研究方向 | GAN圖像生成、情緒對抗樣本生成
元學習(meta-learning)是過去幾年最火爆的學習方法之一,各式各樣的 paper 都是基于元學習展開的。深度學習模型訓練模型特別吃計算硬件,尤其是人為調(diào)超參數(shù)時候,更需要大量的計算。另一個頭疼的問題是在某個任務(wù)下大量數(shù)據(jù)訓練的模型,切換到另一個任務(wù)后,模型就需要重新訓練,這樣非常耗時耗力。工業(yè)界財大氣粗有大量的 GPU 可以承擔起這樣的計算成本,但是學術(shù)界因為經(jīng)費有限經(jīng)不起這樣的消耗。元學習可以有效的緩解大量調(diào)參和任務(wù)切換模型重新訓練帶來的計算成本問題。
元學習介紹
元學習希望使得模型獲取一種學會學習調(diào)參的能力,使其可以在獲取已有知識的基礎(chǔ)上快速學習新的任務(wù)。機器學習是先人為調(diào)參,之后直接訓練特定任務(wù)下深度模型。元學習則是先通過其它的任務(wù)訓練出一個較好的超參數(shù),然后再對特定任務(wù)進行訓練。
在機器學習中,訓練單位是樣本數(shù)據(jù),通過數(shù)據(jù)來對模型進行優(yōu)化;數(shù)據(jù)可以分為訓練集、測試集和驗證集。在元學習中,訓練單位是任務(wù),一般有兩個任務(wù)分別是訓練任務(wù)(Train Tasks)亦稱跨任務(wù)(Across Tasks)和測試任務(wù)(Test Task)亦稱單任務(wù)(Within Task)。訓練任務(wù)要準備許多子任務(wù)來進行學習,目的是學習出一個較好的超參數(shù),測試任務(wù)是利用訓練任務(wù)學習出的超參數(shù)對特定任務(wù)進行訓練。訓練任務(wù)中的每個任務(wù)的數(shù)據(jù)分為 Support set 和 Query set;Test Task 中數(shù)據(jù)分為訓練集和測試集。
令 表示需要設(shè)置的超參數(shù), 表示神經(jīng)網(wǎng)絡(luò)待訓練的參數(shù)。元學習的目的就是讓函數(shù) 在訓練任務(wù)中自動訓練出 ,再利用 這個先驗知識在測試任務(wù)中訓練出特定任務(wù)下模型 中的參數(shù) ,如下所示的依賴關(guān)系:
當訓練一個神經(jīng)網(wǎng)絡(luò)的時候,具體一般步驟有,預(yù)處理數(shù)據(jù)集 ,選擇網(wǎng)絡(luò)結(jié)構(gòu) ,設(shè)置超參數(shù) ,初始化參數(shù) ,選擇優(yōu)化器 ,定義損失函數(shù) ,梯度下降更新參數(shù) 。具體步驟如下圖所示:
元學習會去學習所有需要由人去設(shè)置和定義的參數(shù)變量 。在這里參數(shù)變量 屬于集合為 ,則有:
不同的元學習,就要去學集合 中不同的元素,相應(yīng)的就會有不同的研究領(lǐng)域。
學習預(yù)處理數(shù)據(jù)集 :對數(shù)據(jù)進行預(yù)處理的時候,數(shù)據(jù)增強會增加模型的魯棒性,一般的數(shù)據(jù)增強方式比較死板,只是對圖像進行旋轉(zhuǎn),顏色變換,伸縮變換等。元學習可以自動地,多樣化地為數(shù)據(jù)進行增強,相關(guān)的代表作為 DADA。
論文名稱:DADA: Differentiable Automatic Data Augmentation
論文鏈接:https://arxiv.org/pdf/2003.03780v1.pdf
論文詳情:ECCV 2020
學習初始化參數(shù) :權(quán)重參數(shù)初始化的好壞可以影響模型最后的分類性能,元學習可以通過學出一個較好的權(quán)重初始化參數(shù)有助于模型在新的任務(wù)上進行學習。元學習學習初始化參數(shù)的代表作是 MAML(Model-Agnostic-Meta-Learning)。它專注于提升模型整體的學習能力,而不是解決某個具體問題的能力,訓練時,不停地在不同的任務(wù)上切換,從而達到初始化網(wǎng)絡(luò)參數(shù)的目的,最終得到的模型,面對新的任務(wù)時可以學習得更快。
論文名稱:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
論文鏈接:https://arxiv.org/pdf/1703.03400.pdf
論文詳情:ICML2017
學習網(wǎng)絡(luò)結(jié)構(gòu) :神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)設(shè)定是一個很頭疼的問題,網(wǎng)絡(luò)的深度是多少,每一層的寬度是多少,每一層的卷積核有多少個,每個卷積核的大小又該怎么定,需不需要 dropout 等等問題,到目前為止沒有一個定論或定理能夠清晰準確地回答出以上問題,所以神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索 NAS 應(yīng)運而生。歸根結(jié)底,神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)其實是元學習地一個子類領(lǐng)域。值得注意的是,網(wǎng)絡(luò)結(jié)構(gòu)的探索不能通過梯度下降法來獲得,這是一個不可導(dǎo)問題,一般情況下會采用強化學習或進化算法來解決。
論文名稱:Neural Architecture Search with Reinforcement Learning
論文鏈接:https://arxiv.org/abs/1611.01578
論文詳情:ICLR 2017
學習選擇優(yōu)化器 :神經(jīng)網(wǎng)絡(luò)訓練的過程中很重要的一環(huán)就是優(yōu)化器的選取,不同的優(yōu)化器會對優(yōu)化參數(shù)時對梯度的走向有很重要的影響。熟知的優(yōu)化器有Adam,RMsprop,SGD,NAG等,元學習可以幫我們在訓練特定任務(wù)前選擇一個好的的優(yōu)化器,其代表作有:
論文名稱:Learning to learn by gradient descent by gradient descent
論文鏈接:https://arxiv.org/pdf/1606.04474.pdf
論文詳情:NIPS 2016
元學習訓練
元學習分為兩個階段,階段一是訓練任務(wù)訓練;階段二為測試任務(wù)訓練。對應(yīng)于一些論文的算法流程圖,訓練任務(wù)是在 outer loop 里,測試任務(wù)任務(wù)是在 inner loop 里。
2.1 階段一:訓練任務(wù)訓練
在訓練任務(wù)中給定 個子訓練任務(wù),每個子訓練任務(wù)的數(shù)據(jù)集分為 Support set 和 Query set。首先通過這 個子任務(wù)的 Support set 訓練 ,分別訓練出針對各自子任務(wù)的模型參數(shù) 。然后用不同子任務(wù)中的 Query set 分別去測試 的性能,并計算出預(yù)測值和真實標簽的損失 。接著整合這 個損失函數(shù)為 :
最后利用梯度下降法去求出 去更新參數(shù) ,從而找到最優(yōu)的超參設(shè)置;如果 不可求,則可以采用強化學習或者進化算法去解決。階段一中訓練任務(wù)的訓練過程被整理在如下的框圖中。
2.2 階段二:測試任務(wù)訓練
測試任務(wù)就是正常的機器學習的過程,它將數(shù)據(jù)集劃分為訓練集和測試集。階段一中訓練任務(wù)的目的是找到一個好的超參設(shè)置 ,利用這個先驗知識可以對特定的測試任務(wù)進行更好的進行訓練。階段二中測試任務(wù)的訓練過程被整理在如下的框圖中。
實例講解
上一節(jié)主要是給出了元學習兩階段的學習框架,這一節(jié)則是給出實例并加以說明。明確超參 為初始化權(quán)重參數(shù),通過元學習讓模型學習出一個較優(yōu)的初始化權(quán)重。假設(shè)在 AcrossTasks 中有 個子任務(wù),第 個子任務(wù) Support set 和 Query set 分別是 和 。第 個子任務(wù)的網(wǎng)絡(luò)權(quán)重參數(shù)為 ,元學習初始化的參數(shù)為 的原理圖如下所示,其具體過程為:
第一步:將所有子任務(wù)分類器的網(wǎng)絡(luò)結(jié)構(gòu)設(shè)置為一樣的,從 個子任務(wù)中隨機采樣出 個子任務(wù),并將初始權(quán)重 賦值給這 個網(wǎng)絡(luò)結(jié)構(gòu)。
第二步:采樣出的 個子任務(wù)分別在各自的 Support set 上進行訓練并更新參數(shù) 。在 MAML 中參數(shù) 更新一步,在 Reptile 中參數(shù) 更新多步。
第三步:利用上一步訓練出的 在 Query set 中進行測試,計算出各自任務(wù)下的損失函數(shù) 。
第四步:將不同子任務(wù)下的損失函數(shù) 進行整合得到 。
第五步:求出損失函數(shù) 關(guān)于 的導(dǎo)數(shù),并對初始化參數(shù) 進行更新。
循環(huán)以上個步驟,直到達到要求為止。
為了能夠更直觀的給出利用 更新參數(shù) 的過程,我硬著頭皮把梯度 的顯示表達式給寫了出來,具體形式如下所示:
從這個公式中也能隱約的發(fā)現(xiàn)整個訓練過程的縮影,它已經(jīng)把所有的變量都囊括了進去,這個公式也直接回答了一個問題,元學習自動學習權(quán)重參數(shù) ?是一個可導(dǎo)問題。
整理到這里有一個問題必須要被回答,元學習學習初始化權(quán)重的方法和預(yù)訓練方法有什么區(qū)別?為了能夠更直觀的對比這兩個方法的異同,將預(yù)訓練的過程整理為如下流程圖,具體的過程為:
第一步:前提只有一個神經(jīng)網(wǎng)路模型其初始化權(quán)重參數(shù)為 ,從 個子任務(wù)中隨機采樣出 個子任務(wù)。
第二步:神經(jīng)網(wǎng)絡(luò)模型在采樣出的 個子任務(wù)中進行訓練,得到不同子任務(wù)中的損失 。
第三步:將不同子任務(wù)下的損失函數(shù) 進行整合得到 。
第四步:求出損失函數(shù) 關(guān)于 的導(dǎo)數(shù),并對初始化參數(shù) 進行更新。
循環(huán)以上個步驟,直到達到要求為止。對應(yīng)的在預(yù)訓練過程中,梯度 的表達式為:
可以發(fā)現(xiàn)在相同的網(wǎng)絡(luò)結(jié)構(gòu)下,預(yù)訓練是只有一套模型參數(shù)在不同的任務(wù)中進行訓練,元學習是在不同的任務(wù)中有不同的模型參數(shù)進行訓練。對比二者的梯度公式可以發(fā)現(xiàn),預(yù)訓練過程簡單粗暴它想找到一個在所有任務(wù)(實際情況往往是大多數(shù)任務(wù))上都表現(xiàn)較好的一個初始化參數(shù),這個參數(shù)要在多數(shù)任務(wù)上當前表現(xiàn)較好。元學習過程相對繁瑣,但它更關(guān)注的是初始化參數(shù)未來的潛力。
特別鳴謝
感謝 TCCI 天橋腦科學研究院對于 PaperWeekly 的支持。TCCI 關(guān)注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學術(shù)熱點剖析、科研心得或競賽經(jīng)驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺已發(fā)表或待發(fā)表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競爭力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯(lián)系方式(微信),以便我們在稿件選用的第一時間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的一文通俗讲解元学习(Meta-Learning)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 余额宝体验金在哪里
- 下一篇: 多篇顶会看个体因果推断(ITE)的前世今