label smoothing(标签平滑)
label smoothing是一種在分類問題中,防止過擬合的方法。
label smoothing(標簽平滑)
- 交叉熵損失函數在多分類任務中存在的問題
- label smoothing(標簽平滑)
- 參考資料
交叉熵損失函數在多分類任務中存在的問題
多分類任務中,神經網絡會輸出一個當前數據對應于各個類別的置信度分數,將這些分數通過softmax進行歸一化處理,最終會得到當前數據屬于每個類別的概率。
qi=exp(zi)∑j=1kexp(zj)q_i={{exp(z_i)}\over{\sum_{j=1}^kexp(z_j)}}qi?=∑j=1k?exp(zj?)exp(zi?)?
然后計算交叉熵損失函數:
Loss=?∑i=1kpilogqiLoss=-\sum_{i=1}^k p_i \space log\space q_iLoss=?i=1∑k?pi??log?qi?
pi={1,if(i=y)0,if(i≠y)p_i=\left\{\begin{matrix} 1,if(i=y)\\0,if(i\neq y) \end{matrix}\right.pi?={1,if(i=y)0,if(i?=y)?
其中i表示多分類中的某一類其中i表示多分類中的某一類其中i表示多分類中的某一類
訓練神經網絡時,最小化預測概率和標簽真實概率之間的交叉熵,從而得到最優的預測概率分布。最優的預測概率分布是:
Zi={+∞,if(i=y)0,if(i≠y)Z_i=\left\{\begin{matrix} +\infty,if(i=y)\\0,if(i\neq y) \end{matrix}\right.Zi?={+∞,if(i=y)0,if(i?=y)?
神經網絡會促使自身往正確標簽和錯誤標簽差值最大的方向學習,在訓練數據較少,不足以表征所有的樣本特征的情況下,會導致網絡過擬合。
label smoothing(標簽平滑)
label smoothing可以解決上述問題,這是一種正則化策略,主要通過soft one-hot來加入噪聲,減少真實樣本標簽的類別在計算損失函數時的權重,最終起到抑制過擬合的效果。
增加label smoothing后真實的概率分布有如下改變:
pi={1,if(i=y)0,if(i≠y)p_i=\left\{\begin{matrix} 1,if(i=y)\\0,if(i\neq y) \end{matrix}\right.pi?={1,if(i=y)0,if(i?=y)?
pi={(1??),if(i=y)?K?1,if(i≠y)p_i=\left\{\begin{matrix} (1-\epsilon),if(i=y)\\{{\epsilon}\over{K-1}},if(i\neq y) \end{matrix}\right.pi?={(1??),if(i=y)K?1??,if(i?=y)?
K表示多分類的類別總數K表示多分類的類別總數K表示多分類的類別總數
?是一個較小的超參數\epsilon是一個較小的超參數?是一個較小的超參數
交叉熵損失函數的改變如下:
Loss=?∑i=1kpilogqiLoss=-\sum_{i=1}^k p_i \space log\space q_iLoss=?i=1∑k?pi??log?qi?
Loss={(1??)?Loss,if(i=y)??Loss,if(i≠y)Loss=\left\{\begin{matrix} (1-\epsilon)*Loss,if(i=y)\\ \epsilon*Loss,if(i\neq y) \end{matrix}\right.Loss={(1??)?Loss,if(i=y)??Loss,if(i?=y)?
最優預測概率分布如下:
Zi={+∞,if(i=y)0,if(i≠y)Z_i=\left\{\begin{matrix} +\infty,if(i=y)\\0,if(i\neq y) \end{matrix}\right.Zi?={+∞,if(i=y)0,if(i?=y)?
Zi={log(k?1)(1??)?+α,if(i=y)α,if(i≠y)Z_i=\left\{\begin{matrix} log{{(k-1)(1-\epsilon)}\over{\epsilon+\alpha}},if(i=y)\\\alpha,if(i\neq y) \end{matrix}\right.Zi?={log?+α(k?1)(1??)?,if(i=y)α,if(i?=y)?
這里的α是任意實數,最終模型通過抑制正負樣本輸出差值,使得網絡有更強的泛化能力。
參考資料
總結
以上是生活随笔為你收集整理的label smoothing(标签平滑)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 螺旋矩阵II
- 下一篇: PyTorch项目使用Tensorboa