长文总结半监督学习(Semi-Supervised Learning)
?PaperWeekly 原創 ·?作者|燕皖
單位|淵亭科技
研究方向|計算機視覺、CNN
在現實生活中,無標簽的數據易于獲取,而有標簽的數據收集起來通常很困難,標注也耗時和耗力。在這種情況下,半監督學習(Semi-Supervised Learning)更適用于現實世界中的應用,近來也已成為深度學習領域熱門的新方向,該方法只需要少量有帶標簽的樣本和大量無標簽的樣本,而本文主要介紹半監督學習的三個基本假設和三類方法。
Base Assumptions
在什么假設下可以應用半監督算法呢?半監督算法僅在數據的結構保持不變的假設下起作用,沒有這樣的假設,不可能從有限的訓練集推廣到無限的不可見的集合。具體地假設有:
1.1 The Smoothness Assumption
如果兩個樣本 x1,x2 相似,則它們的相應輸出 y1,y2 也應如此。這意味著如果兩個輸入相同類,并且屬于同一簇,則它們相應的輸出需要相近,反之亦成立。
1.2 The Cluster Assumption
假設輸入數據點形成簇,每個簇對應于一個輸出類,那么如果點在同一個簇中,則它們可以認為屬于同一類。聚類假設也可以被視為低密度分離假設,即:給定的決策邊界位于低密度地區。兩個假設之間的關系很容易看出。
一個高密度區域,可能會將一個簇分為兩個不同的類別,從而產生屬于同一聚類的不同類,這違反了聚類假設。在這種情況下,我們可以限制我們的模型在一些小擾動的未標記數據上具有一致的預測,以將其判定邊界推到低密度區域。
1.3 The Manifold Assumption
(a)輸入空間由多個低維流形組成,所有數據點均位于其上;
(b)位于同一流形上的數據點具有相同標簽。
Consistency Regularization
深度半監督學習的一個新的研究方向是利用未標記的數據來強化訓練模型,使其符合聚類假設,即學習的決策邊界必須位于低密度區域。這些方法基于一個簡單的概念,即如果對一個未標記的數據應用實際的擾動,則預測不應發生顯著變化,因為在聚類假設下,具有不同標簽的數據點在低密度區域分離。
具體來說,給定一個未標記的數據點 及其擾動的形式 ,目標是最小化兩個輸出之間的距離:
流行的距離測量 d 通常是均方誤差(MSE),Kullback-Leiber 散度(KL)和 Jensen-Shannon 散度(JS),我們可以按以下方式計算這些度量,其中 。
具體到每一種算法,核心思想是沒有變化的,即最小化未標記數據與其擾動輸出兩者之間的距離,但計算輸出的形式上有很多變化。
2.1 Pi-Model (ICLR2017)
論文標題:
Temporal Ensembling for Semi-Supervised Learning
論文鏈接:
https://openreview.net/forum?id=BJ6oOfqge¬eId=BJ6oOfqge
代碼鏈接:
https://github.com/smlaine2/tempens
具體來說,由于正則化技術(例如 data augment 和 dropout)通常不會改變模型輸出的概率分布,Pi-Model 正是利用神經網絡中這種預測函數的特性,對于任何給定的輸入 x,使用不同的正則化然后預測兩次,而目標是減小兩次預測之間的距離, 提升模型在不同擾動下的一致性,Pi-Model 使用 MSE 做為兩個概率分布之間的損失函數。
訓練過程如上圖所示:對每一個參與訓練的樣本,在訓練階段,Pi-Model 需要進行兩次前向推理。此處的前向運算,包含一次隨機增強變換和不做增強的前向運算。由于增強變換是隨機的,同時模型采用了 Dropout,這兩個因素都會造成兩次前向運算結果的不同。
損失函數:由兩部分構成,如下所示,其中第一項含有一個時變系數 w,用來逐步釋放此項的權重,x 是未標記數據,由兩次前向運算結果的均方誤差(MSE)構成。第二項由交叉熵構成,x 是標記數據,y 是對應標簽,僅用來評估有標簽數據的誤差。可見,第一項即是用來實現一致性正則。
2.2 Temporal Ensembling (ICLR2017)
論文標題:
Temporal Ensembling for Semi-Supervised Learning
論文鏈接:
https://openreview.net/forum?id=BJ6oOfqge¬eId=BJ6oOfqge
代碼鏈接:
https://github.com/smlaine2/tempens
在 Pi-Model 的基礎上進一步提出了 Temporal Ensembling,其整體框架與 Pi-model 類似,在獲取無標簽數據的處理上采用了相同的思想,唯一的不同是:
在目標函數的無監督一項中, Pi-Model 是兩次前向計算結果的均方差,而在 temporal ensembling 模型中,使用時序組合模型,采用的是當前模型預測結果與歷史預測結果的平均值做均方差計算。有效地保留歷史了信息,消除了擾動并穩定了當前值。
如上圖所示,對于一個目標 ,在每次訓練迭代中,當前輸出 通過 EMA(exponential moving averag,指數滑動平均)更新被累加到整體輸出中 yema:
而損失函數與 Pi-Model 相似。相對于 Pi-Model,Temporal Ensembling 有兩方面的好處:
用空間來換取時間,總的前向推理次數減少了一半,因而減少了訓練時間;
通過歷史預測做平均,有利于平滑單次預測中的噪聲。
2.3 Mean teachers (NIPS 2017)
論文標題:
Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results
論文鏈接:
https://arxiv.org/abs/1703.01780
代碼鏈接:
https://github.com/CuriousAI/mean-teacher
如上圖所示,Mean Teachers 則是 Temporal Ensembling 的改進版,Temporal Ensembling 對模型的預測值進行 EMA(exponential moving averag),而 Mean Teachers 采用了對 studenet 模型權重進行 EMA,作為 teacher model ?如下:
這種情況下,損失的計算是有監督和無監督損失的總和:
2.4 Unsupervised Data Augmentation
論文標題:
Unsupervised Data Augmentation for Consistency Training
論文鏈接:
https://arxiv.org/pdf/1904.12848v2.pdf
代碼鏈接:
https://github.com/google-research/uda
之前的工作中對未標記的數據加入噪聲增強的方式主要是采用簡單的隨機噪聲,但是這篇文章發現對輸入 x 增加的噪聲 α 對模型的性能提升有著重要的影響,因此 UDA 提出對未標記的數據采取更多樣化更真實的數據增強方式,并且對未標記的數據上優化相同的平滑度或一致性增強目標。訓練過程如下圖所示:
(1)最小化未標記數據和增強未標記數據上預測分布之間的 KL 差異:
公式中 x 是原始未標記數據的輸入, 是對未標簽數據進行增強(如:圖像上進行 AutoAugmen,文本進行反翻譯)后的數據。
(2)為了同時使用有標記的數據和未標記的數據,添加了標記數據的 Supervised Cross-entropy Loss 和上式中定義的一致性/平滑性目標 Unsupervised Consistency Loss,權重因子 λ 為我們的訓練目標,最終目標的一致性損失函數定義如下:
UDA 證明了針對性的數據增強效果明顯優于無針對性的數據增強,這一點和監督學習的 AutoAugment、RandAugment 的結論是一致的。
2.5 小節
一致性正則化這類方法的主要思想是:對于無標簽圖像,添加噪聲之后模型預測也應該保持不變。除了以上的方法外,還有 VAT [1]、ICT [2] 等等方法,這些方法也都是找到一種更適合的數據增強,因為數據增強不應該是一成不變的,而是如 UDA 所述不同的任務其數據擴增應該要不一樣。
Proxy-label Methods
代理標簽方法是使用預測模型或它的某些變體生成一些代理標簽,這些代理標簽和有標記的數據混合一起,提供一些額外的訓練信息,即使生成標簽通常包含嘈雜,不能反映實際情況。
這類方法主要可分為分為兩類:self-training(模型本身生成代理標簽)和 multi-view learning(代理標簽是由根據不同數據視圖訓練的模型生成的)。
3.1 Self-training
如上圖所示,Self-training 的訓練過程如下:
Step1:首先,用少量的標簽數據 L 訓練 Model;也就是上圖的虛線以上部分;
Step2:然后,使用訓練后的 Model 給未標記的數據點 x∈U 分配 Pseudo-label(偽標簽);
最受歡迎的兩種方式是銳化方法和 Argmax 方法。前者在保持預測值分布的同時使分布有些極端;后者僅使用對預測具有最高置信度的預測標簽進行標記。如下所示:
另一方面:我們還可以對無標簽數據進行過濾,如果預測結果大于預定閾值 τ,再將其添加訓練中。
Setp3:通過交叉熵損失計算模型預測和偽標簽的損失。
Step4:最后,使用訓練好的模型為 U 的其余部分生成代理標簽,一直循環,直到模型無法生成代理標簽為止。
以下就是 Self-training 的偽代碼:
而 Pseudo-label [5] 與 Self-traing 基本思想是一致的,但這類方法主要缺點是:模型無法糾正自己的錯誤。如果模型對自己預測的結果很有“自信”,但這種自信是盲目的,那么結果就是錯的,這種偏差就會在訓練中得到放大。
3.2 Multi-view training
Multi-view training 利用了在實際應用中非常普遍的多視圖數據。多視圖數據可以通過不同的測量方法(例如顏色信息和紋理)收集不同的視圖圖片信息,或通過創建原始數據的有限視圖來實現。
在這種情況下,MVL 的目標是學習獨特的預測函數 fθi 為數據點 x 的給定視圖 vi(x) 建模,并共同優化所有用于提高泛化性能的功能。理想情況下,可能的觀點相互補充以便所生產的模型可以相互協作以提高彼此的性能。
3.2.1 Co-training
Co-training [3] 有 m1 和 m2 兩個模型,它們分別在不同的數據集上訓練。每輪迭代中,如果兩個模型里的一個模型,比如模型 m1 認為自己對樣本 x 的分類是可信的,置信度高,分類概率大于閾值 τ ,那 m1 會為它生成偽標簽,然后把它放入 m2 的訓練集。
簡而言之,一個模型會為另一個模型的輸入提供標簽。以下是它的偽代碼:
3.2.2 Tri-Training
Tri-training [4] 首先對有標記示例集進行可重復取樣(bootstrap sampling)以獲得三個有標記訓練集,然后從每個訓練集產生一個分類器。
在協同訓練過程中,各分類器所獲得的新標記示例都由其余兩個分類器協作提供,具體來說,如果兩個分類器對同一個未標記示例的預測相同,則該示例就被認為具有較高的標記置信度,并在標記后被加入第三個分類器的有標記訓練集。偽代碼如下:
Holistic Methods
Holistic Methods 試圖在一個框架中整合當前的 SSL 的主要方法,從而獲得更好的性能。
4.1 MixMatch【NeurIPS 2019】
MixMatch 整合了前面提到的一些 ideas 。對于給定一批有標簽的 X 和同樣大小未標簽的 U,先生成一批經過處理的增強標簽數據 X' 和一批偽標簽的 U',然后分別計算帶標簽數據和未標簽數據的損失項。表示為:
對于 alpha,這是一個與 Mixup 操作相關的參數,建議從 0.75 開始,并根據數據集進行調整。
具體操作如下:
Setp 1:Data Augmentation
與許多 SSL 方法中的典型方法一樣,我們對標記的和未標記的數據都使用數據增強。數據增強只是標準的裁剪和翻轉。
Step 2:Label Guessing
對于的每個未標記的訓練數據,MixMatch 使用模型的預測為樣本的生成一個“guess”標簽,這個“guess”標簽被用于無監督損失計算。具體地,我們計算了該模型預測的分類分布在所有 K 個增量上的平均值。如下:
每個未標記的輸入數據只增加兩次擴增(K=2):
Step 3:Sharpening
Sharpening 是一個很重要的過程,這個思想相當于深度學習中的 relu 過程。在給定預測的平均值的基礎上,應用銳化函數減小了標簽分布的熵。如下:
Sharpen 函數實際上只是一個“溫度調整”,建議將溫度參數 T 保持為 0.5。
Step 4:MixUp
與過去使用 MixUp 工作不同,將標記的示例與未標記的示例“混合”在一起,并發現提升了性能。具體地,將有標簽數據 X 和無標簽數據 U 混合在一起形成一個混合數據 W。
然后有標簽數據 X 和 W 中的前 #X 個進行 mixup 后,得到的數據作為有標簽數據,作為 label group,記為 X',同樣,無標簽數據 U 和 W 中的后 #U 個進行 mixup 后,得到的數據作為無標簽數據,作為 unlabel group,記為 U'。
Loss function:對于有標簽的數據,使用交叉熵;“guess”標簽的數據使用MSE;然后將兩者加權組合,如下圖所示。
4.2 FixMatch
FixMatch 是 Google Brain 提出的一種 Holistic 的半監督學習方法,與以往的Holistic Methods不同的是,FixMatch 使用交叉熵將 weakly augment 和 strong augment 的無標簽數據進行比較,并取得了不錯的效果。其巧妙之處是:一致性正則化使用的是交叉熵損失函數。
FixMatch 是對弱增強圖像與強增強圖像之間的進行一致性正則化,但是其沒有使用兩種圖像的概率分布一致,而是使用弱增強的數據制作了偽標簽,這樣就自然需要使用交叉熵進行一致性正則化了。此外,FixMatch 僅使用具有高置信度的未標記數據參與訓練。
增強
弱增強:用標準的翻轉和平移策略。
強增強:輸出嚴重失真的輸入圖像,先使用 RandAugment 或 CTAugment,再使用 CutOut 增強。
模型
FixMatch使用 Wide-Resnet 變體作為基礎體系結構,記為 Wide-Resnet-28-2,其深度為 28,擴展因子為 2。因此,此模型的寬度是 ResNet 的兩倍。
訓練
訓練過程如下:
Input:準備了 batch=B 的有標簽數據和 batch=μB 的無標簽數據,其中 μ 是無標簽數據的比例;
監督訓練:對于在標注數據的監督訓練,將常規的交叉熵損失 H() 用于分類任務。有標簽數據的損失記為 ls,如偽代碼中第 2 行所示;
生成偽標簽:對無標簽數據分別應用弱增強和強增強得到增強后的圖像,再送給模型得到預測值,并將弱增強對應的預測值通過 argmax 獲得偽標簽;
一致性正則化:將強增強對應的預測值與弱增強對應的偽標簽進行交叉熵損失計算,未標注數據的損失由 lu 表示,如偽代碼中的第 7 行所示;式 τ 表示偽標簽的閾值;
完整損失函數:最后,我們將 ls 和 lu 損失相結合,如偽代碼第 8 行所示,對其進行優化以改進模型,其中,λu 是未標記數據對應損失的權重。
總結
當標注的數據較少時模型訓練容易出現過擬合,一致性正則化方法通過鼓勵無標簽數據擾動前后的預測相同使學習的決策邊界位于低密度區域,很好緩解了過擬合這一現象,代理標簽法通過對未標記數據制作偽標簽然后加入訓練,以得到更好的決策邊界,而眾多方法中,混合方法表現出了良好的性能,是近來的研究熱點。
參考文獻
[1] Takeru M , Shin-Ichi M , Shin I , et al. Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018:1-1.
[2] Verma V , Lamb A , Kannala J , et al. Interpolation Consistency Training for Semi-Supervised Learning[J]. 2019.
[3] Avrim Blum and Tom Mitchell. Combining labeled and unlabeled data with co-training. In Proceedings of the eleventh annual conference on Computational learning theory, pages 92–100, 1998.
[4] Zhi-Hua Zhou and Ming Li. Tri-training: Exploiting unlabeled data using three classififiers. IEEE Transactions on knowledge and Data Engineering, 17(11):1529–1541, 2005.
[5] Dong-Hyun Lee. Pseudo-label: The simple and effiffifficient semi-supervised learning method for deep neural networks. In Workshop on challenges in representation learning, ICML, volume 3, page 2, 2013.
更多閱讀
#投 稿?通 道#
?讓你的論文被更多人看到?
如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。?
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得或技術干貨。我們的目的只有一個,讓知識真正流動起來。
?????來稿標準:
? 稿件確系個人原創作品,來稿需注明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)?
? 如果文章并非首發,請在投稿時提醒并附上所有已發布鏈接?
? PaperWeekly 默認每篇文章都是首發,均會添加“原創”標志
?????投稿郵箱:
? 投稿郵箱:hr@paperweekly.site?
? 所有文章配圖,請單獨在附件中發送?
? 請留下即時聯系方式(微信或手機),以便我們在編輯發布時和作者溝通
????
現在,在「知乎」也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」訂閱我們的專欄吧
關于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號后臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結
以上是生活随笔為你收集整理的长文总结半监督学习(Semi-Supervised Learning)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 新版五元人民币长什么样 新旧对比很明显
- 下一篇: 期货反手是什么意思啊