学习报告:基于原型网络的小样本学习《Prototypical Networks for Few-shot Learning》
學習報告:基于原型網絡的小樣本學習《Prototypical Networks for Few-shot Learning》
- 一、概述
- 二、方法解析
- 三、實驗
- 3.1 說明
- 3.2 Omniglot分類
- 3.3 miniImageNet分類
- 四、總結分析
本篇學習報告基于論文《Prototypical Networks for Few-shot Learning》,該論文的主要貢獻有兩點:(1)對圖像領域的Few-Shot/Zero-Shot(小樣本/零樣本)任務,應用設計簡單的原型網絡方法(見第二部分),在通用數據集上達到了較好的實驗效果(見第三部分);(2)對原型網絡本身進行了較為深入的分析,且分析了距離度量方式的選擇對任務效果的影響(見圖3)。
原文鏈接及開源代碼已置于文末。
一、概述
在小樣本分類問題中,最需要解決的一個問題是數據的過擬合問題。由于訓練數據過少,一般的分類算法會表現出過擬合的現象,從而導致分類結果與實際結果有較大的誤差。為了減少因數據量過少而導致的過擬合的影響,可以使用基于度量的元學習方法,該論文所提出的原型網絡便屬于這種方法。
該論文為解決小樣本分類問題提出了原型網絡。在訓練集中,對于每一種出現的類別,只給出少量樣本,但分類器能夠很好的泛化到其他沒有出現于訓練集中的新類別。原型網絡會學習一個度量空間,在該空間中,可以通過計算與每個類的對應原型表示的距離來進行分類,距離哪個類的原型表示最近,則被判斷為哪個類。與最近的小樣本學習方法相比,該方法反映了一種更簡單的歸納偏差,有利于在這種有限的數據范圍內使用,并取得優異的效果。論文表明一些簡單的設計決策比最近涉及復雜體系結構選擇和元學習的方法可以產生較好的改進效果。
介紹兩類常見的Few-Shot方法:
匹配網絡(Matching Network):
可以理解為在embedding空間中的加權最近鄰分類器。模型在訓練過程中通過對類標簽和樣本的二次采樣來模仿Few-Shot任務的測試場景,學習一個匹配網絡。該網絡只在訓練集中的關系基礎上訓練,并且直接應用于測試集中的關系。原型網絡也屬于一種匹配網絡。實驗和總結中將對原型網絡和匹配網絡的不同之處和分類效果進行比較。
Optimization-based meta-learning:
這種方法在訓練的過程中的目標是學習如何通過少量樣本更好的擬合數據,因此該類方法會針對測試數據集對網絡進行調整。例如,在訓練過程中,利用LSTM的網絡結構學習每個訓練step所需要的學習率。
二、方法解析
在該論文所提出的原型網絡方法中,需要將樣本投影到一個度量空間,且在這個空間中同類樣本距離較近,異類樣本的距離較遠。下圖為這個投影空間的示意圖,假如在這個投影空間中,存在三個類別的樣本,且相同類別的樣本間距離較近。為了給一個未標注樣本x進行標注,則將樣本x投影至這個空間并計算x與各個類別的原型距離,離得近的就認為x屬于哪個類別。
圖1 投影空間示意圖那么,現在有幾個問題:
1、怎么將這些樣本投影至一個空間且讓同類樣本間距離較近?
2、怎么說明一個類別所在的位置,從而能夠讓未標記的樣本計算與類別的距離?
如何將樣本投影至一個空間且讓同類樣本間距離較近?論文中使用的是一個帶參數φ的嵌入函數fφ(x),這個函數可以理解為投影的過程,x表示樣本的特征向量,函數值表示投影到那個空間后的值,這個嵌入函數fφ(x)是一個神經網絡,參數φ是需要學習的,可以認為參數φ決定了樣本間的位置,所以需要學習到一個較好的φ值,讓同類別樣本間距離較近。
此外,還需要考慮如何說明一個類別所在位置,論文中認為一個類的位置由這個類所有樣本在投影空間里的平均值決定,類k的原型表示公式如下:
其中Sk表示類k,|Sk|表示類k中樣本的數量,(xi , yi)為樣本的特征向量和標記,此公式實際上為一個求平均的過程。
得到每個類的原型后,就需要根據樣本與各個類的原型的距離,求一個樣本屬于一個類的概率。因為在訓練時這個樣本是已標記的,即我們已知類k的原型,已知一個屬于類k的樣本,求此樣本屬于類k的概率,因此我們的目標函數就是求這個概率的最大值。
此公式所表示的意義是,對于樣本x,求它到每個類的距離,然后進行歸一化操作得到概率,即x屬于類k的概率。其中d為距離函數,在本篇論文中使用的是歐幾里得距離。在訓練過程中,x的標簽是已知的。論文中的目標函數為:
一般通過隨機梯度下降方法來求它的最小值,從而收斂后學到一個好的φ值。可以認為,訓練結束后此投影函數可以將同類的樣本投影到一個相互距離較近的地方。
字符說明:
N:訓練集中樣例的數量
K:訓練集中類的數量
NC:每個Episode中類別的數量
NS:每個類中支持樣例的數量
NQ:每個類中查詢樣例的數量
以下Algorithm 1給出了計算訓練集損失J(Φ)的偽代碼
計算過程:為Episode選擇類別 → 選擇支持集 → 選擇訓練集 → 計算支持集的原型 → 初始化損失 → 更新損失
在測試過程中,使用與訓練過程中相同的投影函數方法,求每個類的原型,根據一個未標記的樣本x,求屬于每個類的概率,認為概率值大的那個,即為x屬于的類別。
總結原型網絡的基本思想:基于集群,找到類的原型,找到合適距離度量方式進行分類。
三、實驗
3.1 說明
實驗的數據分為支持集和查詢集:
支持集:即訓練集,在該論文中由一些已標記的樣本組成,比如有N個類,每個類中有M個樣本,則為N-way–M-shot。
查詢集:即測試集,在該論文中由一些已標記的樣本和部分未標記的樣本組成,后續實驗結果表明訓練集的way大于測試集的話分類結果更好(我認為這有助于提高模型的泛化性),而shot最好一致(我認為是為了保持不同類別樣本的平衡性)。
3.2 Omniglot分類
Omniglot是一個1623個手寫字符分類的數據集。每一個字符類別只有20個樣本,不同樣本由不同的人繪制。
該論文使用原形網絡在Omniglot數據集上進行實驗,使用歐幾里得距離作為距離度量,分別在1-shot和5-shot進行實驗。下圖為某個子集的度量空間的可視化,其中黑色點代表每種類別的原形,紅色代表被錯誤分類的數據,紅色箭頭的指向為真實的類別。
圖2 Omniglot數據集中某個子集的度量空間的t-SNE可視化圖訓練episode的設置為60個類別和每個類別有5個query查詢點。實驗結果發現在訓練和測試時保持相同的樣本數據量(即shot相同)和episode使用更多的類別(即way更大)會使得實驗效果更好。下表展示的是該論文所提出的方法與其他方法在Omniglot數據集上的結果對比。
表1 Omniglot數據集分類結果比較3.3 miniImageNet分類
minilmageNet數據集包含100個類別,每個類別中包含600個樣本數據。其中64個類別數據作為訓練集,16個類別數據作為驗證集,20個類別數據作為測試集。
表2 miniImageNet數據集分類結果比較實驗分別對1-shot和5-shot的設置進行訓練episode為5-way和20-way的訓練,實驗結果表明也訓練episode中設置更多的類別,對實驗的結果有一定的增益效果,這是因為更大的way設置有助于網絡進行更好的泛化,使得模型在度量空間做出更細粒度的決策。
還有個比較有意思的實驗結果:在N-way M-shot問題中的M=1,也就是one-shot的情況下,prototype network實際上等價于matching network;此外,無論是one-shot還是M-shot(M>1),歐氏距離(Euclid.)的效果都要比余弦距離(Cosine)的效果好(如下圖所示),因此本文使用的距離計算公式為歐氏距離。
四、總結分析
本論文提出的Prototypical Networks(P-net)思想與Matching Networks(M-net)十分相似,兩種網絡主要有以下不同點:1.使用了不同的距離度量方式,M-net中是余弦距離,P-net中使用的是屬于布雷格曼散度的歐幾里得距離。2.二者在few-shot的場景下不同,在one-shot時等價(one-shot時取得的原型就是支持集中的樣本,相當于不用進行平均處理)3.網絡結構上,P-net將編碼層和分類層合一,參數更少,訓練更加方便。論文的實驗部分中也在不同數據集上進行了兩種網絡的效果比較,結果顯示P-net的效果要優于M-net。本論文提出的原型網絡方法雖然結構設計比較簡單,但是卻能達到很好的效果,這為我們在解決小樣本分類問題時提供了一種可行的解決思路。
論文地址:https://arxiv.org/pdf/1703.05175.pdf
源代碼:https://github.com/jakesnell/prototypical-networks
總結
以上是生活随笔為你收集整理的学习报告:基于原型网络的小样本学习《Prototypical Networks for Few-shot Learning》的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: java 线程 设计模式_Java多线程
- 下一篇: allegro如何编辑铜皮