【论文学习】ICLR2021,鲁棒早期学习法:抑制记忆噪声标签ROBUST EARLY-LEARNING: HINDERING THE MEMORIZATION OF NOISY LABELS
??論文來自ICLR2021,作者是悉尼大學的Xiaobo Xia博士。論文基于早停和彩票假說,提出了一種處理標簽噪聲問題的新方法。我就論文要點學習整理,目前還沒有找到開源代碼,我實現了一份在本文中給出。我對論文中部分試驗復現,并補充進行一些新試驗。
??論文鏈接
文章目錄
- 一、理論要點
- 二、公式推導
- 三、效果對比
- 四、我的代碼及部分試驗復現
- 1,核心代碼
- 2,我的試驗
- 2.1,不同噪聲率下觀察“早停”的作用
- 2.2,不同τ\tauτ參數下觀察“彩票假說”現象
- 2.3,不同噪聲率和不同τ\tauτ參數下觀察本文算法去噪效果
- 2.4,算法局部修改試驗
- 2.4.1 (1?τ1-\tau1?τ)
- 2.4.2 L1正則
- 2.4.3 gig_{i}gi?
- 五、讀后感
一、理論要點
這篇文章基于兩點主要理論:一是深度網絡會先記憶標簽清晰的訓練數據,然后記憶標簽有噪聲的訓練數據。因此,用早停法學習可抑制噪聲標簽。二是彩票假說指出深度網絡中只有部分參數對模型起作用,本文因此認為只有部分參數對擬合干凈標簽有用,稱之為關鍵參數,而其他參數則傾向于擬合噪聲標簽,稱之為非關鍵參數。在每次迭代中,對不同的參數執行不同的更新規則以逐漸使非關鍵參數歸零,以此抑制噪聲標簽發揮作用。二、公式推導
文中總共有以下6個公式:
min L(W;S)L(\mathcal{W};S)L(W;S) = min1n∑i=1nL(W;(xi,yi))+λ∥W∥1\frac{1}{n}\sum \limits_{i=1} ^{n}L(\mathcal{W};(x_{i},y_{i})) + \lambda\begin{Vmatrix}\mathcal{W}\end{Vmatrix}_{1}n1?i=1∑n?L(W;(xi?,yi?))+λ∥∥?W?∥∥?1? ???????(1)
W(k+1)←W(k)?η(?L(W(k);S?)?W(k)+λsgn(W(k)))\mathcal{W}(k+1)\leftarrow\mathcal{W}(k) - \eta(\frac{\partial L(\mathcal{W}(k);S^{*})}{\partial\mathcal{W}(k)}+\lambda sgn(\mathcal{W}(k)))W(k+1)←W(k)?η(?W(k)?L(W(k);S?)?+λsgn(W(k)))??????(2)
gi=∣?L(Wi;S)×Wi∣,i∈[m]g_{i}=|\nabla L(\tiny W_{i}\normalsize ;S) \times \tiny W_{i}\normalsize |, i\in[m]gi?=∣?L(Wi?;S)×Wi?∣,i∈[m]?????????????????(3)
mc=(1?τ)mm_{c}=(1-\tau)mmc?=(1?τ)m ????????????????????????(4)
Wc(k+1)←Wc(k)?η((1?τ)?L(Wc(k);S?~)?Wc(k)+λsgn(Wc(k)))\mathcal{W}_{c}(k+1)\leftarrow\mathcal{W}_{c}(k) - \eta((1-\tau)\frac{\partial L(\mathcal{W}_{c}(k);\tilde{S^{*}})}{\partial\mathcal{W}_{c}(k)}+\lambda sgn(\mathcal{W}_{c}(k)))Wc?(k+1)←Wc?(k)?η((1?τ)?Wc?(k)?L(Wc?(k);S?~)?+λsgn(Wc?(k))) (5)
Wn(k+1)←Wn(k)?ηλsgn(Wn(k))\mathcal{W}_{n}(k+1)\leftarrow\mathcal{W}_{n}(k) - \eta \lambda sgn(\mathcal{W}_{n}(k))Wn?(k+1)←Wn?(k)?ηλsgn(Wn?(k)) ??????????? (6)
考慮給損失函數加入一個l1正則項,如式(1);
根據式(1)的損失函數,使用SGD方式更新權重,如式(2);
對于任一個參數Wi∈Wm\tiny W_{i}\normalsize \in {\mathcal{W}^{m}}Wi?∈Wm,根據式(3)計算一個參考量gig_{i}gi?,根據gig_{i}gi?對W\mathcal{W}W排序。根據式(4)計算得到關鍵參數的個數為mcm_{c}mc?個,然后W\mathcal{W}W排序考前的mcm_{c}mc?個參數就是關鍵參數Wc\mathcal{W}_{c}Wc?,其余參數為非關鍵參數Wn\mathcal{W}_{n}Wn?;
對于關鍵參數按照(5)式更新,注意梯度乘上了一個衰減系數(1?τ1-\tau1?τ),作者說這是為了防止訓練過程中過度自信下降。(對此不是很理解)
對于非關鍵參數按照(6)式更新,此時把梯度置零,只保留了正則化項,這會導致這些非關鍵參數逐漸縮小直到接近于0而失去作用。
其中公式(3)比較難理解,為什么用這個指標來判斷哪些是關鍵參數呢?原文的解釋如下:
構造一個函數G(t)=L(tW;S)G(t)=L(\mathcal{tW};S)G(t)=L(tW;S),則
G′(t)=?L(tW;S)TWG'(t)=\nabla L(\mathcal{tW};S)^{T}\mathcal{W}G′(t)=?L(tW;S)TW,
令t=1t=1t=1,有:
G′(1)=?L(W;S)TW=<?L(W;S),W>G'(1)=\nabla L(\mathcal{W};S)^{T}\mathcal{W}=<\nabla L(\mathcal{W};S),\mathcal{W}>G′(1)=?L(W;S)TW=<?L(W;S),W>(<>表示內積)
滿足最優化條件時,?L(W;S)=0\nabla L(\mathcal{W};S)=0?L(W;S)=0,因此G′(1)=0G'(1)=0G′(1)=0,
由G′(1)=0G'(1)=0G′(1)=0可得到(3)式
說實話,這個部分我沒有看懂,有理解的小伙伴可以講一講。
三、效果對比
??作者指出由于本文的主要目的是提出一個新的概念,并且本文沒有使用多種綜合措施,所以效果趕不上該領域在2020年的兩個SOTA方法:DivideMix和SELF,除了這兩個之外,本文方法比其他模型的效果都好。作者進行了大量對比試驗,其中在MNIST、F-MNIST、CIFAR-10、CIFAR-100這四個數據集上的試驗如表1。
??作者隨后又在Food-101和WebVision這兩個數據集上進行了試驗,結論類似。
??作者又進行了消融試驗,試驗發現模型效果對參數τ\tauτ不敏感。
四、我的代碼及部分試驗復現
1,核心代碼
??由于沒有開源,我按照自己理解進行代碼實現。根據文中公式,該算法只涉及到參數更新過程,因此只需要在pytorch中重寫SGD即可實現本算法中說的關鍵/非關鍵參數分別更新;然后在訓練的時候加入早停即可。
??重寫的newSGD代碼如下,主要是增加了tau和decay1兩個參數。tau就是文中τ\tauτ噪聲率,注意式(6)和式(5)的區別,對于非關鍵參數,就是把梯度項置零,只有正則化項了,所以代碼可以非常簡潔的寫出來。在SGD中,weight_decay就是正則化項,但是torch1.6給出的SGD用的是l2正則,而論文中給出的公式用的是l1正則,所以我又新加了一個weight_decay1用來實現l1正則。
然后在訓練時把原來的SGD替換即可
from newSGD import newSGD optimizer = newSGD(net.parameters(), lr=0.01,momentum=0.9, tau=0.2, weight_decay1=1e-3)2,我的試驗
??為了加快速度,試驗主要在MNIST數據集和LeNet上進行,個別補充進行了CIFAR10上的ResNet18試驗。試驗參數配置:epoch = 100, BatchSize = 128, lr=0.01 ,momentum = 0.9, weight_decay = 0.001。由于L1正則不便于觀察規律(原因見2.4.2節介紹),下面試驗使用L2正則。噪聲數據只使用同步噪聲標簽,即每個類別按照噪聲率抽取樣本隨機變換為任意其他類別的標簽。注意噪聲只存在于訓練集,測試集不含噪聲,是干凈的。
2.1,不同噪聲率下觀察“早停”的作用
??神經網絡在訓練早期只學習干凈標簽,在訓練的后期才逐漸學習噪聲標簽,因此可以用早停法抑制噪聲標簽。我們先觀察這個現象,試驗中不使用本文提到的新算法,只使用LeNet和交叉熵損失:
??從圖中可以看出幾個特點:
(1)隨著噪聲率的增加,訓練集訓練精度明顯降低,但測試集仍能達到較高的精度,例如即使噪聲含量80%時,此時訓練集精度不足35%,但測試集精度最高仍可達到85%以上。這說明神經網絡本身就對噪聲有一定的魯棒性。
(2)含噪聲時,網絡早期先學習干凈數據,所以測試集仍可以達到很高精度,但后期開始記憶噪聲數據,導致測試集精度下降。所以早停肯定可以起到抑制噪聲標簽的作用。
(3)對比噪聲含量80%和90%的訓練精度曲線(圖中淺藍和深藍虛線),我們發現一個有意思的地方,90%噪聲的訓練精度后期比80%的還高。我的解釋是:由于數據集就10個類別,90%噪聲時幾乎等于完全隨機,網絡從一開始就意識到這沒有任何規律可以找,干脆就快速發展記憶數據能力了。這很有意思,值得繼續思考。
2.2,不同τ\tauτ參數下觀察“彩票假說”現象
??彩票假說指出神經網絡只有少部分參數真正發揮作用。上面newSGD算法中給出的τ\tauτ會使得網絡中每個參數張量中都有占比例為τ\tauτ的參數在經過充分訓練后趨于0,因此使用這個代碼就可以觀察到彩票假說現象。我們使用不含噪聲的數據來觀察這個現象:
從圖中可以看出,神經網絡具有驚人的參數壓縮潛力,τ=0.995\tau=0.995τ=0.995時,相當于只有0.5%的參數起作用,測試精度仍可達到95%以上。τ=0.999\tau=0.999τ=0.999時,訓練結束后,我們把其中conv2層的權重絕對值reshape到25×96以及fc1層的權重絕對值進行可視化,畫出來如下圖。可見其中確實只有極少的參數存在了,但即使這么稀疏的參數,仍然可以達到70%以上的精度。τ=0.9999\tau=0.9999τ=0.9999時,網絡的效果才有明顯的下降,但仍有接近40%的精度。
2.3,不同噪聲率和不同τ\tauτ參數下觀察本文算法去噪效果
又在CIFAR10上用ResNet18做了部分試驗,效果和上圖類似:
??從圖中可以看出:
??τ=0\tau=0τ=0就是論文Table1中的CE,使用本算法之后,τ\tauτ較大時起到的作用只是隨著訓練的繼續,測試精度下降變少,但考慮到早停時,最佳精度發生在初期,使用本方法后和CE并無明顯優勢。這可能是MNIST數據集過于簡單,加的噪聲模式也比較簡單,所以看不出論文算法的優勢。這個和論文中的Table1也是一致的。
2.4,算法局部修改試驗
??對算法中的衰減系數(1?τ1-\tau1?τ),l1正則,劃分關鍵參數的判據gig_{i}gi?等的作用和必要性仍不太理解,因此我們從試驗對比中觀察它們的效果。
2.4.1 (1?τ1-\tau1?τ)
對于式(5)中的(1?τ1-\tau1?τ)項,在原本的SGD公式中是沒有的,作者說這里增加此項能夠抑制過度自信下降的作用,下圖以20%噪聲率為例,對比了使用(1?τ1-\tau1?τ)和不使用(1?τ1-\tau1?τ)的效果。
從圖中可以看出,當τ\tauτ=0.8或0.9時,(1?τ1-\tau1?τ)項能夠起到一定的正則效果,會避免訓練的后期記憶噪聲數據,但效果并不明顯。
2.4.2 L1正則
下圖給出L1正則和L2正則在20%噪聲率時的測試集精度曲線,可以看出L1正則的正則化效果更重,即使τ\tauτ較小時也可以防止模型后期記憶噪聲數據。但是L1正則在模型初期的精度表現不如L2正則,也就是說如果使用早停的話其效果不如L2。由于L1正則過強的正則化效果,不便于觀察2.1,2.2節中的現象,所以前序試驗都使用L2正則進行。
2.4.3 gig_{i}gi?
??gig_{i}gi?是劃分關鍵和非關鍵參數的依據,作者在公式(3)中給出的計算方法是參數的梯度和參數的點積的絕對值。作者的推導過程我沒有看懂(數學太菜了!),但我可以用試驗檢驗以下這個表達式的充分必要性,也就是
- 使用式(3)能否把參數壓縮到少量關鍵參數;
- 使用式(3)確定的關鍵參數是否真的關鍵,即是否能以少量關鍵參數仍達到和全量參數接近的精度;
??文中公式(3)我在代碼中寫成 g = (d_p * p).abs(),我又嘗試了其他幾種劃分關鍵和非關鍵參數的方法,
??方法B:g = d_p.abs() + p.abs()
??方法C:提前隨機選定每個參數張量中占比τ\tauτ的位置制成mask,然后每輪參數更新時,這些位置對應的參數的梯度置0。
??我們定義絕對值大于0.001的參數為有效參數,上圖的第一行三個圖表示的是隨著訓練輪數,網絡中的總有效參數量的變化情況,第二行三個圖表示隨著訓練輪數,測試集精度的變化。
??從上面圖中對比我們可以看出,對于本文方法(最左圖),在不同的τ\tauτ下都能使有效參數量逐漸收縮到占比總參數量約為τ\tauτ的位置處,并且精度仍能夠有著不錯的保持。而對于另外兩種方法,它們不能夠保持有效參數不再壓縮,而是會出現參數量不斷的下降,精度也掉的一塌糊涂,說明這兩種方法不能有效區分關鍵參數和非關鍵參數,也就不能夠在訓練后期把關鍵參數穩定住。實際上我還嘗試了很多其他的參數劃分方法,都沒有文中方法有效。
??所以說文中式(3)給出的關鍵參數劃分判據是非常有效的,對公式的推導過程后續再慢慢吃透。
??(補充說明,第一行圖中可以明顯觀察到有效參數量每次都是在75epoch和95epoch處有明顯轉折,這個原因是網絡使用的默認的標準參數初始化方式,參數的分布概率是固定的,而同樣的weight_decay下參數的收縮速率也是固定的,所以會有同批的參數被同時收縮到0.001以下。)
五、讀后感
??本文提出的方法實際上主要是從彩票假說和神經網絡早期學習干凈標簽這兩點出發,本文方法的噪聲標簽抑制能力實際上達不到SOTA。但彩票假說中只是指出了神經網絡中真正關鍵的參數很少,卻也沒有指出有效的提取關鍵參數的方法,而本文提出的劃分關鍵參數的方法非常有意思,有可能提供一種新的模型壓縮的思路。這篇論文的寫作也非常好,值得學習。
<補充 2021-02-09>更具tau修正梯度的核心部分代碼修改如下,能夠進一步提高精度,加快運算速度。 m = p.numel()if tau != 0 and m>1000:g = (d_p * p).abs()if m>10000:gf = g.flatten()[:10000]mn = int(10000*(1-100/math.sqrt(m)*(1-tau)))if mn > 9990:mn = 9990kth,_ = gf.kthvalue(mn)else:mn = int(p.numel()*tau)kth,_ = g.flatten().kthvalue(mn)d_p = torch.where(g < kth, torch.zeros_like(d_p), d_p)
總結
以上是生活随笔為你收集整理的【论文学习】ICLR2021,鲁棒早期学习法:抑制记忆噪声标签ROBUST EARLY-LEARNING: HINDERING THE MEMORIZATION OF NOISY LABELS的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch CookBook
- 下一篇: LeNet试验(五)观察“彩票假说”现象