元学习之《On First-Order Meta-Learning Algorithms》论文详细解读
元學(xué)習(xí)系列文章
文章目錄
- 引言
- On First-Order Meta-Learning Algorithms
- 偽算法
- 數(shù)學(xué)過程
- 訓(xùn)練過程
- 實(shí)驗(yàn)
- 核心代碼
- OpenAI Demo
- 幾點(diǎn)思考
- 參考資料
引言
上一篇博客對(duì)論文 MAML 做了詳細(xì)解讀,MAML 是元學(xué)習(xí)方向 optimization based 的開篇之作,還有一篇和 MAML 很像的論文 On First-Order Meta-Learning Algorithms,該論文是大名鼎鼎的 OpenAI 的杰作,OpenAI 對(duì) MAML 做了簡(jiǎn)化,但效果卻優(yōu)于 MAML,具體做了什么簡(jiǎn)化操作,請(qǐng)往下看😀。
On First-Order Meta-Learning Algorithms
這篇論文的標(biāo)題就很針對(duì) MAML,MAML 中有一個(gè)重要的特點(diǎn),就是在求梯度時(shí),為了加速放棄了二階求導(dǎo),使用一階微分近似進(jìn)行代替,雖然效果上相差不大,但總感覺少了點(diǎn)什么。這篇論文的標(biāo)題上來就聲稱我們是一階的 metalearning 方法,而且剛好是在 MAML 發(fā)表的下一年(2018)發(fā)表在 ICML 會(huì)議的,從標(biāo)題上也是賺慢了噱頭。
還有個(gè)有意思的事情,OpenAI 把論文中的算法稱之為 Reptile, 但是也沒有解釋為什么叫這個(gè),論文中也沒看出來和 Reptile 有什么關(guān)聯(lián),感興趣的讀者,可以去深究一下。
說了一堆廢話,下面開始進(jìn)入正題。
偽算法
貼一張論文中的官方算法:
先來解釋一下:
1 首先初始化一個(gè)網(wǎng)絡(luò)模型的所有參數(shù) ? \phi ?
2 迭代 N 次,進(jìn)行訓(xùn)練,每次迭代執(zhí)行:
- 2.1 隨機(jī)抽樣一個(gè)任務(wù) T,用網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,對(duì)應(yīng)的loss 是 L t L_t Lt?,訓(xùn)練結(jié)束后的參數(shù)是 ? ~ \widetilde{\phi} ? ?
- 2.2,在參數(shù) ? \phi ?上使用 SGD 或 Adam 執(zhí)行K次梯度下降更新,得到 ? ~ = U t k ( ? ) \widetilde{\phi}={U}^{k}_{t}(\phi) ? ?=Utk?(?)
- 2.3 用 ? ~ \widetilde{\phi} ? ?更新網(wǎng)絡(luò)模型模型參數(shù), ? = ? + ? ( ? ~ ? ? ) \phi=\phi+\epsilon(\widetilde{\phi}-\phi) ?=?+?(? ???)
3 完成上述N次迭代訓(xùn)練,則結(jié)束整個(gè)過程
從上面的算法中可以看出,Reptile 是在每個(gè)單獨(dú)的任務(wù)執(zhí)行K次訓(xùn)練后,就開始真正更新網(wǎng)絡(luò)模型的參數(shù)(Meta),更新方式不是梯度下降,但是和梯度下降公式長(zhǎng)得很像,是用上一次的參數(shù) ? \phi ?和K次后的參數(shù) ? ~ \widetilde{\phi} ? ?的差來更新,更新的步長(zhǎng)是 ? \epsilon ?。在這個(gè)過程中,只有一階求導(dǎo)的計(jì)算,就是在任務(wù)內(nèi)部執(zhí)行K次更新的過程中用到的隨機(jī)梯度下降,這也是為什么標(biāo)題中叫 First-Order 的原因。
從這就可以看出和 MAML 算法的不同了:
這里說的meta參數(shù),就是真正更新網(wǎng)絡(luò)模型參數(shù)的過程
數(shù)學(xué)過程
上面只是簡(jiǎn)單介紹了 Reptile 的算法思想,下面從數(shù)學(xué)過程上來理解下它的更新過程,先來設(shè)定幾個(gè)符號(hào):
? \phi ?代表網(wǎng)絡(luò)模型初始參數(shù), ? , η \epsilon,\eta ?,η分別代表 meta 更新的學(xué)習(xí)率和 task 更新的學(xué)習(xí)率, N N N是meta訓(xùn)練的 batch_size,即 meta 的一個(gè)bach有 N 個(gè)task,每個(gè)task內(nèi)部執(zhí)行K次訓(xùn)練,N個(gè)任務(wù)都訓(xùn)練完,再來更新meta參數(shù)。按照上面的算法過程,meta的一個(gè)batch訓(xùn)練完之后,網(wǎng)絡(luò)模型的參數(shù)是:
? = ? + ? 1 N ∑ i = 1 N ( ? i ~ ? ? ) = ? + ? ( W ? ? ) \begin{aligned} \phi &= \phi +\epsilon \frac{1}{N}\sum_{i=1}^{N}\left ( \tilde{\phi_i } -\phi\right )\\ &= \phi +\epsilon \left ( W-\phi \right )\\ \end{aligned} ??=?+?N1?i=1∑N?(?i?~???)=?+?(W??)?
其中 W W W是每個(gè)任務(wù)最后參數(shù)的平均值,上述公式再進(jìn)行展開就是這樣:
假設(shè)N=2,K=3,即meta每次訓(xùn)練的一個(gè)batch 有2個(gè)task,每個(gè)task內(nèi)部進(jìn)行3此迭代,則 meta每次更新模型參數(shù)的公式為:
訓(xùn)練過程
上面公式的最后一行,又變成了熟悉的梯度下降,只不過梯度方向是每個(gè)任務(wù)內(nèi)部更新的幾次梯度方向的和。meta 模型的參數(shù)更新過程,在幾何上就是這樣的:
動(dòng)圖看的更加清晰些,其中綠色代表第一個(gè)任務(wù),三個(gè)綠色箭頭代表三次更新時(shí)的梯度方向,可以看到,Reptile的模型就是朝著每個(gè)任務(wù)的梯度和的方向上不斷地進(jìn)行更新。
還記得 MAML 是怎樣更新的嗎?不記得的話,請(qǐng)翻看上一篇博客。還是同樣的設(shè)置,MAML 的更新過程如下:
即 MAML 是在每個(gè)任務(wù)最后一個(gè)梯度的方向上進(jìn)行更新,而 Reptile 是在每個(gè)任務(wù)幾個(gè)梯度和的方向上進(jìn)行更新。
實(shí)驗(yàn)
實(shí)驗(yàn)設(shè)置和 MAML 論文中的設(shè)置一樣,回歸任務(wù)以擬合正弦函數(shù)為例,分類任務(wù)以 MiniImagenet 數(shù)據(jù)和 omniglot 數(shù)據(jù)的圖片分類為例,詳細(xì)設(shè)置就不再贅述了,直接看實(shí)驗(yàn)結(jié)果:
上半部分的圖是正弦函數(shù)的擬合結(jié)果,(b)是MAML的結(jié)果,C是Reptile的結(jié)果,橘黃色線是微調(diào)32次之后的樣子,綠色線是真實(shí)分布,可以看到 Reptile和MAML的結(jié)果相當(dāng),都能擬合到真實(shí)分布的樣子,硬要一較高下的話,那就是 Reptile稍好一些。
下半部分圖是在 MiniImagenet 分類數(shù)據(jù)上的結(jié)果,作者也對(duì)比了一階近似 MAML和二階MAML的結(jié)果,從圖中可以看出,Reptile的準(zhǔn)確率至少要高出1個(gè)百分點(diǎn)。
在論文中作者還對(duì)比了一個(gè)有意思的實(shí)驗(yàn),Reptile 既然可以在 g 1 + g 2 + g 3 g_1+g_2+g_3 g1?+g2?+g3? 的梯度方向上更新,那么如果在其它梯度的組合方向上去更新,結(jié)果會(huì)怎樣呢?比如 g 1 + g 2 g_1+g_2 g1?+g2? 等方向,作者也針對(duì)不同梯度的組合進(jìn)行了實(shí)驗(yàn),實(shí)驗(yàn)結(jié)果如下:
橫軸是meta迭代次數(shù),縱軸是準(zhǔn)確率,不同顏色的曲線代表不同的梯度組合,可以明顯的看到最下面的藍(lán)色曲線準(zhǔn)確率最低,藍(lán)色曲線代表在 g 1 g_1 g1? 第一個(gè)梯度方向上去更新,其實(shí)就是模型預(yù)訓(xùn)練的過程,以所有訓(xùn)練任務(wù)的 loss 為準(zhǔn)進(jìn)行更新。其他顏色的曲線都代表用若干次之后的 loss 來更新參數(shù),最上面的那條曲線代表 Reptile,即用 g 1 + g 2 + g 3 + g 4 g_1+g_2+g_3+g_4 g1?+g2?+g3?+g4? 的梯度方向進(jìn)行更新,只使用 g 4 g_4 g4? 的那條曲線代表 MAML。
核心代碼
Reptile 的論文代碼也是開源的,而且代碼很簡(jiǎn)介規(guī)范,不愧是 OpenAI 出品。建議感興趣的讀者去看下論文源碼,不僅能更好的理解論文思想,對(duì)工程能力的提升也很有幫助,包括代碼風(fēng)格、模塊化、組織架構(gòu)、邏輯實(shí)現(xiàn)等都有很多值得借鑒的地方。關(guān)于源代碼有疑問的話,可以私信聯(lián)系我。這里只貼一點(diǎn)核心的訓(xùn)練更新代碼,對(duì)應(yīng)上面的數(shù)學(xué)過程:
代碼文件見 reptile.py
# 取出網(wǎng)絡(luò)模型的最新參數(shù)old_vars = self._model_state.export_variables()# 保存一個(gè) meta batch 里,每個(gè) task 更新 K 次后的參數(shù)new_vars = []for _ in range(meta_batch_size):# 抽樣出一個(gè) taskmini_dataset = _sample_mini_dataset(dataset, num_classes, num_shots)for batch in _mini_batches(mini_dataset, inner_batch_size, inner_iters, replacement):# task 里面的訓(xùn)練,更新 inner_iters 次,相當(dāng)于公式中的Kinputs, labels = zip(*batch) # inner_iters 個(gè) batch,每個(gè) iter 使用一個(gè) batch ,里面的一次訓(xùn)練迭代if self._pre_step_op:self.session.run(self._pre_step_op)self.session.run(minimize_op, feed_dict={input_ph: inputs, label_ph: labels})# 一個(gè) task 內(nèi)部訓(xùn)練完的參數(shù)new_vars.append(self._model_state.export_variables())self._model_state.import_variables(old_vars)# 對(duì) meta_batch 個(gè) task 的最終參數(shù)進(jìn)行平均,相當(dāng)于公式中的 Wnew_vars = average_vars(new_vars)# 所有的 meta_batch 個(gè)任務(wù)都訓(xùn)練完, 更新一次 meta 參數(shù),并且把更新后的參數(shù)更新到計(jì)算圖中,下次訓(xùn)練從最新參數(shù)開始# 更新方式:old + scale*(new - old)self._model_state.import_variables(interpolate_vars(old_vars, new_vars, meta_step_size))OpenAI Demo
在 OpenAI 的官方博客 Reptile: A Scalable Meta-Learning Algorithm中,也有介紹這篇論文。該博客網(wǎng)頁中還有個(gè)有意思的 demo,大家可以試玩一下:
這個(gè) demo 的意思是,openAI 已經(jīng)用他們的 Reptile 算法訓(xùn)練了一個(gè)用于少樣本場(chǎng)景的3分類網(wǎng)絡(luò)模型,并且嵌入到了網(wǎng)頁中,用戶可以通過 demo 中的交互制作一個(gè)新的三分類任務(wù),并且這個(gè)任務(wù)只有三個(gè)訓(xùn)練樣本,也就是每個(gè)類下只有一個(gè)樣本,學(xué)名叫3-Way 1-shot,讓他們的模型在這三個(gè)樣本上進(jìn)行微調(diào)學(xué)習(xí),然后在右邊畫一個(gè)新的三個(gè)類別下的測(cè)試樣本,Reptile 模型會(huì)自動(dòng)給出它在三個(gè)類別下的概率。通過這個(gè) demo 來證明他們的模型確實(shí)有奇效,在新任務(wù)的幾個(gè)樣本上微調(diào)一下,就可以在該任務(wù)的測(cè)試集上取得很好的準(zhǔn)確率。
幾點(diǎn)思考
通過上面的 demo 可以得出一些結(jié)論:
參考資料
- https://arxiv.org/pdf/1803.02999.pdf
- https://github.com/openai/supervised-reptile
- https://www.bilibili.com/video/BV1Gb411n7dE?p=32
總結(jié)
以上是生活随笔為你收集整理的元学习之《On First-Order Meta-Learning Algorithms》论文详细解读的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。