通过互信息思想来缓解类别不平衡问题
?PaperWeekly 原創 ·?作者|蘇劍林
學校|追一科技
研究方向|NLP、神經網絡
類別不平衡問題,也叫“長尾問題”,是機器學習面臨的常見問題之一,尤其是來源于真實場景下的數據集,幾乎都是類別不平衡的。大概在兩年前,筆者也思考過這個問題,當時正好對“互信息”相關的內容頗有心得,所以構思了一種基于互信息思想的解決辦法,但又想了一下,那思路似乎過于平凡,所以就沒有深究。
然而,前幾天在 arxiv 上刷到 Google 的一篇文章 Long-tail learning via logit adjustment [1],意外地發現里邊包含了跟筆者當初的構思幾乎一樣的方法,這才意識到當初放棄的思路原來還能達到 SOTA 的水平。于是結合這篇論文,將筆者當初的構思過程整理于此,希望不會被讀者嫌棄“馬后炮”。
問題描述
這里主要關心的是單標簽的多分類問題,假設有 共 K 個候選類別,訓練數據為 ,建模的分布為 ,那么我們的優化目標是最大似然,或者說最小化交叉熵,即:
通常來說,我們建立的概率模型最后一步都是 softmax,假設 softmax 之前的結果為 (即 logits),那么:
所謂類別不均衡,就是指某些類別的樣本特別多,就好比“20% 的人占據了 80% 的財富”一樣,剩下的類別數很多,但是總樣本數很少,如果從高到低排序的話,就好像帶有一條很長的“尾巴”,所以叫做長尾現象。
這種情況下,我們訓練的時候采樣一個 batch,很少有機會采樣到低頻類別,因此很容易被模型忽略了低頻類。但評測的時候,通常我們又更關心低頻類別的識別效果,這便是矛盾之處。
常見思路
常見的思路大家應該也有所聽說,大概就是三個方向:
1. 從數據入手,通過過采樣或降采樣等手段,使得每個 batch 內的類別變得更為均衡一些;
2. 從 loss 入手,經典的做法就是類別 y 的樣本 loss 除以類別出現的頻率p(y);
3. 從結果入手,對正常訓練完的模型在預測階段做些調整,更偏向于低頻類別,比如正樣本遠少于負樣本,我們可以把預測結果大于 0.2(而不是 0.5)都視為正樣本。
Google 的原論文中對這三個方向的思路也列舉了不少參考文獻,有興趣調研的讀者可以直接閱讀原論文,另外,知乎上的文章《Long-Tailed Classification (2) 長尾分布下分類問題的最新研究》[2] 也對該問題進行了介紹,讀者也可以參考閱讀。
學習互信息
回想一下,我們是怎么斷定某個分類問題是不均衡的呢?顯然,一般的思路是從整個訓練集里邊統計出各個類別的頻率 p(y),然后發現 p(y) 集中在某幾個類別中。所以,解決類別不平衡問題的重點,就是如何把這個先驗知識 p(y) 融入模型之中。
在之前構思詞向量模型(如文章《更別致的詞向量模型(二):對語言進行建模》[3])的時候,我們就強調過,相比擬合條件概率,如果模型能直接擬合互信息,那么將會學習到更本質的知識,因為互信息才是揭示核心關聯的指標。
但是擬合互信息沒那么容易訓練,容易訓練的是條件概率,直接用交叉熵 進行訓練就行了。所以,一個比較理想的想法就是:如何使得模型依然使用交叉熵為 loss,但本質上是在擬合互信息?
在公式 (2)?中,我們是建模了:
現在我們改為建模互信息,那么那也就是希望:
按照右端的形式重新進行 softmax 歸一化,那么就有:
或者寫成 loss 形式:
原論文稱之為 logit adjustment loss。如果更加一般化,那么還可以加個調節因子 :
一般情況下, 的效果就已經接近最優了。如果 的最后一層有 bias 項的話,那么最簡單的實現方式就是將 bias 項初始化為 。也可以寫在損失函數中:
import?numpy?as?np import?keras.backend?as?Kdef?categorical_crossentropy_with_prior(y_true,?y_pred,?tau=1.0):"""帶先驗分布的交叉熵注:y_pred不用加softmax"""prior?=?xxxxxx??#?自己定義好prior,shape為[num_classes]log_prior?=?K.constant(np.log(prior?+?1e-8))for?_?in?range(K.ndim(y_pred)?-?1):log_prior?=?K.expand_dims(log_prior,?0)y_pred?=?y_pred?+?tau?*?log_priorreturn?K.categorical_crossentropy(y_true,?y_pred,?from_logits=True)def?sparse_categorical_crossentropy_with_prior(y_true,?y_pred,?tau=1.0):"""帶先驗分布的稀疏交叉熵注:y_pred不用加softmax"""prior?=?xxxxxx??#?自己定義好prior,shape為[num_classes]log_prior?=?K.constant(np.log(prior?+?1e-8))for?_?in?range(K.ndim(y_pred)?-?1):log_prior?=?K.expand_dims(log_prior,?0)y_pred?=?y_pred?+?tau?*?log_priorreturn?K.sparse_categorical_crossentropy(y_true,?y_pred,?from_logits=True)結果分析
很明顯 logit adjustment loss 也屬于調整 loss 方案之一,不同的是它是在 里邊調整權重,而常規的思路則是在 外調整。至于它的好處,就是互信息的好處:互信息揭示了真正重要的關聯,所以給 logits 補上先驗分布的 bias,能讓模型做到“能靠先驗解決的就靠先驗解決,先驗解決不了的本質部分才由模型解決”。
在預測階段,根據不同的評測指標,我們可以制定不同的預測方案。從《函數光滑化雜談:不可導函數的可導逼近》[4] 可以知道,對于整體準確率而言,我們有近似:
其中 是驗證集。所以如果不考慮類別不均衡情況,追求更高的整體準確率,那么對于每個 x 我們直接輸出 最大的類別即可。但如果我們希望每個類的準確率都盡可能高,那么我們將上式改寫成:
其中 ,也標簽為 y 的 x 的集合,等號右邊事實上就是先將同一個 y 的項合并起來。我們知道“整體準確率=每一類的準確率的加權平均”,而上式正好具有同樣的形式,所以括號里邊的 就是“每一類的準確率”的一個近似了。
因此,如果我們希望每一類的準確率都盡可能高,我們則要輸出使得 最大的類別(不加權)。結合 的形式,我們有結論:
第一種其實就是輸出條件概率最大者,而第二種就是輸出互信息最大者,按具體需求選擇。
至于詳細的實驗結果,大家可以自行看論文,總之就是好到有點意外:
▲ 原論文的實驗結果文末小結
本文簡單介紹了一種基于互信息思想的類別不平衡處理辦法,該方案以前筆者也曾經構思過,不過沒有深究,而最近 Google 的一篇論文也給出了同樣的方法,遂在此簡單記錄分析一下,最后 Google 給出的實驗結果顯示該方法能達到 SOTA 的水平。
參考文獻
[1]?https://arxiv.org/abs/2007.07314
[2] https://zhuanlan.zhihu.com/p/158638078
[3] https://kexue.fm/archives/4669
[4] https://kexue.fm/archives/6620
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的通过互信息思想来缓解类别不平衡问题的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度解读NLP文本情感分析Pipelin
- 下一篇: 新存科技公布国产最大容量新型 3D 存储