NeurIPS 2021 | 通过寻找平坦最小值,克服小样本增量学习中的灾难性遗忘
?作者 | FlyingBug
單位 | 哈爾濱工業大學(深圳)
研究方向 | 小樣本學習
寫在篇首
本文分享的這篇論文是 NeurIPS?2021的一篇 Few-Shot 增量學習 (FSCIL) 文章,這篇文章通過固定 backbone 和 prototype 得到一個簡單的 baseline,發現這個 baseline 已經可以打敗當前 IL 和 IFSL 的很多 SOTA 方法,基于此通過借鑒 robust optimize 的方法,提出了在 base training 訓練時通過 flat local minima 來對后面的 session 進行 fine-tune novel classes,解決災難性遺忘問題。
論文標題:
Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima
收錄會議:
NeurIPS 2021
論文鏈接:
https://arxiv.org/pdf/2111.01549.pdf
代碼鏈接:
https://github.com/moukamisama/F2M
Motivation
不同于現有方法在學習新任務時嘗試克服災難性遺忘問題,這篇文章提出在訓練 base classes 時就提出策略來解決這個問題。作者提出找到 base training function 的 flat local minima,最小值附近 loss 小,作者認為 base classes 分離地更好(直覺上,flat local minima 會比 sharp 的泛化效果更好,參閱下圖 [2])。
1.2 Contribution
作者發現一個簡單的 baseline model,只要在 base classes 上訓練,不在 new tasks 上進行適應,就超過了現有的 SOTA 方法,說明災難性遺忘問題非常嚴重。作者提出在 primitive stage 來解決災難性遺忘問題,通過在 base classes 上訓練時找到 flat minima region 并在該 region 內學習新任務,模型能夠更好地克服遺忘問題。
1.3 A Simple Baseline?
作者提出了一個簡單的 baseline,模型只在 base classes 上進行訓練,在后續的 session 上直接進行推理。
Training(t=1)
在session 1上對特征提取器進行訓練,并使用一個全連接層作為分類器,使用 CE Loss 作為損失函數,從session 2 () 開始將特征提取器固定住,不使用 novel classes 進行任何 fine-tune 操作。
Inference(test)
使用均值方式獲得每個類的 prototype,然后通過歐氏距離 采用最近鄰方式進行分類。分類器的公式如下:
其中 表示類別 的 prototype, 表示類別 的訓練圖片數量。同時作者將 中所有類的 prototypes 保存下來用于后續的 evaluation。
作者表示通過這種保存 old prototype 的方式就打敗了現有的 SOTA 方法,證明了災難性遺忘非常嚴重。
1.4 Method
核心想法就是在 base training 的過程中找到函數的 flat local minima ,并在后續的 few-shot session 中在 flat region 進行 fine-tune,這樣可以最大限度地保證在 novel classes 上進行 fine-tune 時避免遺忘知識。在后續增量 few-shot sessions () 中,在這個 flat region 進行 fine-tune 模型參數來學習 new classes。
1.4.1 尋找Base Training的flat local minima
為了找到 base training function 的近似 flat local minima,作者提出添加一些隨機噪聲到模型參數,噪聲可以被多次添加以獲得相似但不同的 loss function,直覺上,flat local minima 附近的參數向量有小的函數值。
假設模型的參數 , 表示特征提取網絡的參數, 表示分類器的參數。 表示一個有類標訓練樣本,損失函數 。我們的目標就是最小化期望損失函數。
?是數據分布 是噪聲分布, 和 是相互獨立的。
因此最小化期望損失是不可能的,所以這里我們最小化他的近似,empirical loss:
?是 , 是采樣次數。這個 loss 的前半部分是為了找到 flat region,它的特征提取網絡參數 可以很好地區分 base classes。第二部分是通過 MSE Loss 的設計為了讓 prototype 盡量保持不變, 避免模型遺忘過去的知識。
1.4.2 在Flat Region內進行IFSL?
作者認為雖然 flat region 很小,但對于 few-shot 的少量樣本來說,足夠對模型進行迭代更新。
通過歐氏距離使用基于度量的分類算法來 fine-tune 模型參數。
1.4.3 收斂性分析?
我們的目標是找到一個 flat region 使模型效果較好。然后,通過最小化期望損失(噪聲 和數據 的聯合分布)。為了近似這個期望損失,我們在每次迭代中多次從 采樣,并使用隨機梯度下降 (SGD) 優化目標函數。后面是相關的理論證明,感興趣的可以自行閱讀分析。
參考文獻
[1] Shi G, Chen J, Zhang W, et al. Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima[J]. Advances in Neural Information Processing Systems, 2021, 34.?
[2] He H, Huang G, Yuan Y. Asymmetric valleys: Beyond sharp and flat local minima[J]. arXiv preprint arXiv:1902.00744, 2019.
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
·
總結
以上是生活随笔為你收集整理的NeurIPS 2021 | 通过寻找平坦最小值,克服小样本增量学习中的灾难性遗忘的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 空客330-200机型15年机龄长吗
- 下一篇: 当特种兵需要什么条件?