半监督学习之MixMatch
半監督學習之MixMatch
MixMatch
Unsupervised Data Augmentation for Consistency Training
半監督深度學習訓練和實現小Tricks
MixMatch: A Holistic Approach to Semi-Supervised Learning
1.解讀
超強半監督學習 MixMatch
此方法僅用少量的標記數據,就使半監督學習的預測精度逼近監督學習。
- 自洽正則化(Consistency Regularization)。自洽正則化的思路是,對未標記數據進行數據增廣,產生的新數據輸入分類器,預測結果應保持自洽。即同一個數據增廣產生的樣本,模型預測結果應保持一致。
x 是未標記數據,Augment(x) 表示對x做隨機增廣產生的新數據, [公式] 是模型參數,y 是模型預測結果。注意數據增廣是隨機操作,兩個 Augment(x) 的輸出不同。這個 L2 損失項,約束機器學習模型,對同一個圖像做增廣得到的所有新圖像,作出自洽的預測。
MixMatch 集成了自洽正則化。數據增廣使用了對圖像的隨機左右翻轉和剪切(Crop)。
- 最小化熵(Entropy Minimization)。許多半監督學習方法都基于一個共識,即分類器的分類邊界不應該穿過邊際分布的高密度區域。具體做法就是強迫分類器對未標記數據作出低熵預測。
MixMatch 使用 “sharpening” 函數,最小化未標記數據的熵。
- **傳統正則化(Traditional Regularzation)。**為了讓模型泛化能力更好,一般的做法對模型參數做 L2 正則化,SGD下L2正則化等價于Weight Decay。MixMaxtch 使用了 Adam 優化器,而之前有篇文章發現 Adam 和 L2 正則化同時使用會有問題,因此 MixMatch 從諫如流使用了單獨的Weight decay。
Mixup數據增強方法。從訓練數據中任意抽樣兩個樣本,構造混合樣本和混合標簽,作為新的增廣數據,
這種 MixMatch 方法在小數據上做半監督學習的精度,遠超其他同類模型。比如,在 CIFAR-10 數據集上,只用250個標簽,他們就將誤差減小了4倍(從38%降到11%)。在STL-10數據集上,將誤差降低了兩倍。
對比 MixMatch 使用 250 張標記圖片,就可以將測試誤差降低到 11.08,使用4000張標記圖片,可以將測試誤差降低到 6.24,應該算是大幅度超越使用GAN做半監督學習的效果。
具體實現
1.使用MixMatch算法,對一個Batch的標記數據X和一個Batch的未標記數據U做數據數據增強,分別得到一個Batch的增強數據X’和K個Batch的U’
X ′ , U ′ = M i x M a t c h { X , U , T , K , α } \mathcal {X’,U’=MixMatch\{X,U,T,K,\alpha\}} X′,U′=MixMatch{
X,U,T,K,α}
T,溫度參數(sharpen的超參數);K,對未標記的數據做K次隨機增強,α是Mixup的超參數
2.對X’和U’分別計算損失
|X|等于batch size,|U|等于K倍的batch size,L是分類類別數,H是CE
對于未標注的數據使用L2范數做損失因為L2比CE約束更加嚴格
3.最終的損失是兩者的加權
另一一篇博客
The Quiet Semi-Supervised Revolution
性能和標注數據量的關系
現在的趨勢是
2.論文閱讀
題目:MixMatch:一個半監督學習的整體(Holistic)方法
代碼
- 1.tensorflow
google-research/mixmatch
- 2.pytorch
YU1ut/MixMatch-pytorch
2.1摘要
半監督學習已被證明是一個強大的利用未標簽數據來減輕依賴于大型標簽數據集的范式(paradigm)。
MixMatch估計(guess)低熵的數據增強后的未標注樣本,然后使用Mixup將標注的數據和未標注的數據混合起來。
2.2介紹
SSL,Semi-supervised Learning
許多半監督的學習方法通過增加在未標注的數據上計算的損失項(loss term)來估計模型在沒見過的數據上泛化。
損失項分為3類(falls into one of three classes)
- Entropy Minimization,鼓勵模型在未標注的數據上輸出高置信度(confident predictions)的預測
- Consistency Regularization,鼓勵模型在其輸入受到干擾時產生相同的輸出分布
- Generic Regularization,減少模型過擬合
MixMatch優雅地統一了這些主流的方法(gracefully unifies these dominant approaches)
2.3相關工作
最近的一些SOTA的方法
-
Consistency Regularization
一致性/自洽正則化
數據增強將輸入進行轉換并且認為類別語義不受影響。
粗略地說,數據增強可以通過生成接近無限的新修改數據流來人為地擴展訓練集的大小。一致性正則化將數據增強應用于半監督學習,即分類器應該為未標注的例子輸出相同的類分布。更正式地說,一致性正則化強制一個未標記的樣本x應該和Augment(x)分為一類。
對于一個點x,過去地工作加了一個損失項
Augment(x)是一個隨機地變換,所以2個Augment(*)不等
“Mean Teacher”(2017)將其中一項替換為了模型參數值的滑動平均
MixMatch使用了一種一致性正則化的形式,通過對圖像使用標準的數據增強(隨機水平翻轉和裁切)
Jetbrains全家桶1年46,售后保障穩定
-
Entropy Minimization
許多SSL方法的基本假設是分類器的決策邊界不應該通過數據分布邊際的高密度區域(“非黑即白假設”,想想SVM的決策邊界)。一個強制實現的方法是要求分類器對未標記的數據輸出低熵的預測。
“Pseudo Label”通過對高置信度的結果變為1-hot標簽來隱式地實現低熵
MixMatch通過使用“sharpen”函數來隱式地達到低熵
-
Traditional Regularization
正則化值對模型施以約束來使之更難地記住訓練數據以希望對沒見過地數據泛化。
使用權重衰減來懲罰模型參數的L2范數,使用MixUp來估計樣本之間的凸行為(convex behavior)
2.4 MixMatch
MixMatch是一個”整合“的方法,有上面的主流SSL范式組成。
- 給定batch大小的標注數據和同樣大小的標注數據,記為 X , U \mathcal {X,U} X,U
- MixMatch產生一批增強后的數據和增強后的帶有“猜測”的標簽的增強后的非標注數據,記作 X ′ , U ′ \mathcal {X’,U’} X′,U′
- 使用 X ′ , U ′ \mathcal {X’,U’} X′,U′分別計算標注和未標注損失項
H ( p , q ) H(p,q) H(p,q)是分布p和q的交叉熵, T , K , α , λ U T,K,\alpha,\mathcal{\lambda_U} T,K,α,λU?是超參數,L是類別(X of labeled examples with one-hot targets (representingone of L possible labels)
算法:
算法的偽代碼
標簽“猜測”過程
隨機數據增強對未標注的數據使用K次,每次的增強后的圖片都被輸入分類器。然后,這些K個預測被”銳化“(“sharpened”)通過調整分布的溫度超參。
- 數據增強
對于標注數據生成一個batch size的增強結果,對于非標注數據,我們生成K*batch size的增強結果。對于非標注的數據,生成K個增強結果。使用這些獨立的增強結果來生成”猜測標簽”
- 標簽猜測
對于每個未標注的樣本,MixMatch使用模型的預測產生一個“猜測”的樣本標簽,這個猜測隨后會被用于非監督損失項。
計算K個增強的結果計算平均值:
q ˉ b = 1 K ∑ k = 1 K p m o d e l ( y ∣ u ^ b , k ; θ ) \bar q_b=\frac{1}{K}\sum^K_{k=1}p_{model}(y|\hat u_{b,k};\theta) qˉ?b?=K1?∑k=1K?pmodel?(y∣u^b,k?;θ)
通過對未標記的樣本進行增強獲得的人工結果來實現一致性正則化
使用一致性正則化會帶來域適應(cycleGAN)
- 銳化
通過銳化來減少標簽分布的熵。使用常用的方法來調整類分布的**”溫度“**
當T→0,輸出的結果趨于Dirac分布(one-hot)
- MixUp
將標注的樣本的標簽和非標注樣本的“猜測標簽”混合。
具體做法:
從beta分布中采樣得到權重λ
對于兩對數據標簽對 ( x 1 , p 1 ) , ( x 2 , p 2 ) (x_1,p_1),(x_2,p_2) (x1?,p1?),(x2?,p2?),mix后的結果為 ( x ′ , y ′ ) (x’,y’) (x′,y′)。 λ ′ \lambda’ λ′的作用是使得x’比x2更加靠近x1(使得標注得標簽占比更大)x
α是調整beta分布的超參數。
為了實現Mixup,首先收集所有得增強后的標注樣本標簽和增強后的未標注樣本的“猜測標簽”,然后將結果混洗后作為Mixup的數據源 W \mathcal W W,然后將標注的數據和等量的 W \mathcal W W作為Mixup的輸入得到結果 X ^ \mathcal {\hat X} X^,然后將剩余的 W \mathcal W W中的數據和未標記的帶“猜測標簽”的數據作為Mixup的輸入。
SSL對未標注數據使用L2損失的原因是對不正確的預測不敏感。
消融實驗的結果:
參數EMA似乎是負面的影響
2.5實踐細節
-
超參數的設置
- T=0.5
- K=2
- α=0.75
- λ U \mathcal {\lambda_U} λU?=100
訓練前的16,000中,線性地將 λ U \mathcal{\lambda_U} λU?提高到最大值。
-
模型
- Wide ResNet-28
-
學習率地設置
- 不使用學習率衰減而是使用模型參數值的滑動平均,衰減率為0.999
總結
以上是生活随笔為你收集整理的半监督学习之MixMatch的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 在Ubuntu上安装SAP Cloud
- 下一篇: WebClient UI删除搜索条件的后