FlatNCE:小批次对比学习效果差的原因竟是浮点误差?
?PaperWeekly 原創 ·?作者?|?蘇劍林
單位?|?追一科技
研究方向?|?NLP、神經網絡
自 SimCLR [1] 在視覺無監督學習大放異彩以來,對比學習逐漸在 CV 乃至 NLP 中流行了起來,相關研究和工作越來越多。標準的對比學習的一個廣為人知的缺點是需要比較大的 batch_size(SimCLR 在 batch_size=4096 時效果最佳),小 batch_size 的時候效果會明顯降低,為此,后續工作的改進方向之一就是降低對大 batch_size 的依賴。那么,一個很自然的問題是:標準的對比學習在小 batch_size 時效果差的原因究竟是什么呢??
近日,一篇名為 Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE 對此問題作出了回答:因為浮點誤差。看起來真的很讓人難以置信,但論文的分析確實頗有道理,并且所提出的改進 FlatNCE 確實也工作得更好,讓人不得不信服。
論文標題:
Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE
論文作者:
Junya Chen, Zhe Gan, Xuan Li, Qing Guo, Liqun Chen, Shuyang Gao, Tagyoung Chung, Yi Xu, Belinda Zeng, Wenlian Lu, Fan Li, Lawrence Carin, Chenyang Tao
論文鏈接:
https://arxiv.org/abs/2107.01152
細微之處
接下來,筆者將按照自己的理解和記號來介紹原論文的主要內容。對比學習(Contrastive Learning)就不幫大家詳細復習了,大體上來說,對于某個樣本 x,我們需要構建 K 個配對樣本 ,其中 是正樣本而其余都是負樣本,然后分別給每個樣本對 打分,分別記為 ,對比學習希望拉大正負樣本對的得分差,通常直接用交叉熵作為損失:
簡單起見,后面都記 。在實踐時,正樣本通常是數據擴增而來的高相似樣本,而負樣本則是把 batch 內所有其他樣本都算上,因此大致上可以認為負樣本是隨機選擇的 K-1 個樣本。這就說明,正負樣本對的差距還是很明顯的,因此模型很容易做到 ,也即 。于是,當 batch_size 比較小的時候(等價于 K 比較小), 也會相當接近于 0,這意味著上述損失函數也會相當接近于 0。
損失函數接近于 0,通常也意味著梯度接近于 0 了,然而,這不意味著模型的更新量就很小了。因為當前對比學習用的都是自適應優化器如 Adam,它們的更新量大致形式為 梯度梯度梯度學習率,這就意味著,不管梯度多小,只要它穩定,那么更新量就會保持著 學習率 的數量級。
對比學習正是這樣的場景,要想 ,那么就要 ,但對比學習的打分通常是余弦值除以溫度參數,所以它是有界的, 是無法實現的,因此經過一定的訓練步數后,損失函數將會長期保持接近于 0 但又大于 0 的狀態。
然而, 的計算本身就存在浮點誤差,當 很接近于 0 時,浮點誤差可能比精確值還要大,然后 的計算也會存在浮點誤差,再然后梯度的計算也會存在浮點誤差,這一系列誤差累積下來,很可能導致最后算出來的梯度都接近于隨機噪聲了,而不能提供有效的更新指引。這就是原論文認為的對比學習在小 batch_size 時效果明顯變差的原因。
變微為著
理解了這個原因后,其實也就不難針對性地提出解決方案了。對損失函數做一階展開我們有:
也就是說,一定訓練步數之后,模型相當于以 為損失函數了。當然,由于 ,即 是 的上界,所以就算一開始就以 為損失函數,結果也沒什么差別,現在主要還是解決的問題是 接近于 0 而導致了浮點誤差問題。剛才說了,自適應優化器的更新量大致上都是 梯度梯度梯度學習率 的形式,這意味著如果我們直接將損失函數乘以一個常數,那么理論上更新量是不會改變的,所以既然 過小,那么我們就將它乘以一個常數放大就好了。
乘以什么好呢?比較直接的想法是損失函數不能過小,也不能過大,控制在 級別最好,所以我們干脆乘以 的倒數,也就是以:
為損失函數。這里 是 stop_gradient 的意思(原論文稱為 detach),也就是把分母純粹當成一個常數,求梯度的時候只需要對分子求。這就是原論文提出的替代方案,稱為 FlatNCE。
不過,上述帶 算子形式的損失函數畢竟不是我們習慣的形式,我們可以轉換一下。觀察到:
也就是說, 作為損失函數提供的梯度跟 作為損失函數的梯度是一模一樣的,因此我們可以把損失函數換為不帶 算子的 :
相比于交叉熵,上述損失就是在 運算中去掉了正樣本對的得分 。注意到 通常可以有效地計算,浮點誤差不會占主導,因此我們用上述損失函數取代交叉熵,理論上跟交叉熵是等效的,而實踐上在小 batch_size 時效果比交叉熵要好。此外,需要指出的是,上式結果不一定是非負的,因此換用上述損失函數后在訓練過程中出現負的損失值也不需要意外,這是正常現象。
實驗評估
分析似乎有那么點道理,那么事實是否有效呢?這自然是要靠實驗來說話了。不出意料,FlatNCE 確實工作得非常好。
原論文的實驗都是 CV 的,主要是把 SimCLR 的損失換為 FlatNCE 進行實驗,對應的結果稱為 FlatCLR。其中,我們最關心的大概是 FlatNCE 是否真的解決了對大 batch_size 的依賴問題,下面的圖像則作出了肯定回答:
▲ 不同 batch_size 下 SimCLR 與 FlatCLR 對比圖
下面則是 SimCLR 和 FlatCLR 在各個任務上的結果對比,顯示出 FlatCLR 更好的性能:
▲?SimCLR 和 FlatCLR 在各個任務上的對比
吹毛求疵
總的來說,原論文的結果非常有創造性,“浮點誤差”這一視角非常“刁鉆”但也相當精準,讓人不得不點贊。
直觀來看,原來交叉熵的目標是“正樣本得分與負樣本得分的差盡量大”,這對于常規的分類問題是沒問題的,但對于對比學習來說還不夠,因為對比學習目的是學習特征,除了正樣本要比負樣本得分高這種“粗”特征外,負樣本之間也要繼續對比以學習更精細的特征;FlatNCE 的目標則是“正樣本的得分要盡量大,負樣本的得分要盡量小”,也即從相對值的學習變成了絕對值的學習,從而使得正負樣本拉開一定距離后,依然能夠繼續優化,而不至于過早停止(對于非自適應優化器),或者讓浮點誤差帶來的噪聲占了主導(對于自適應優化器)。
然而,原論文的某些內容設置也不得不讓人吐槽。比如,論文花了較大的篇幅討論互信息的估計,但這跟論文主體并無實質關聯,加大了讀者的理解難度。當然,paper 跟科普不一樣,為了使文章更充實而增加額外的理論推導也無可厚非,只是如果能更突出浮點誤差部分的分析更好。然后,論文最讓我不能理解的地方是直接以式(3)為最終結果,這種帶“stop_gradient”的表述方式雖然算不上難,但也不友好,通常來說這種方式是難以尋求原函數的時候才“不得不”使用的,但 FlatNCE 顯然不是這樣。
總結全文
本文介紹了對比學習的一個新工作,該工作分析了小批次對比學習時交叉熵的浮點誤差問題,指出這可能是小批次對比學習效果差的主要原因,并且針對性地提出了改進的損失函數 FlatNCE,實驗表明基于 FlatNCE 的對比學習確實能緩解對大 batch_size 的依賴,并且能獲得更好的效果。
參考文獻
[1] https://arxiv.org/abs/2002.05709
特別鳴謝
感謝 TCCI 天橋腦科學研究院對于 PaperWeekly 的支持。TCCI 關注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。
?????稿件基本要求:
? 文章確系個人原創作品,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題
? PaperWeekly 尊重原作者署名權,并將為每篇被采納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算
?????投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時聯系方式(微信),以便我們在稿件選用的第一時間聯系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
·
總結
以上是生活随笔為你收集整理的FlatNCE:小批次对比学习效果差的原因竟是浮点误差?的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 支付宝兑换的优酷会员怎么用
- 下一篇: 油价上涨的因素有哪些