pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss
分類問題常用的幾種損失,記錄下來備忘,后續(xù)不斷完善。
nn.CrossEntropyLoss()交叉熵?fù)p失
常用于多分類問題
CE = nn.CrossEntropyLoss() loss = CE(input,target)Input: (N, C) , dtype: float, N是樣本數(shù)量,在批次計(jì)算時(shí)通常就是batch_size
target: (N), dtype: long,是類別號(hào),0 ≤ targets[i] ≤ C?1
pytorch中的交叉熵?fù)p失就是softmax和NLL損失的組合,即
nn.NLLLoss()
NLL = nn.NLLLoss() loss = NLL(input,target)Input: (N, C) , dtype: float, N是樣本數(shù)量,在批次計(jì)算時(shí)通常就是batch_size
target: (N), dtype: long,是類別號(hào),0 ≤ targets[i] ≤ C?1
nn.BCELoss() 二元交叉熵?fù)p失
常用于二分類或多標(biāo)簽分類
BCE = nn.BCELoss() loss = BCE(input,target)Input: (N, x) , dtype: float, N是樣本數(shù)量,在批次計(jì)算時(shí)通常就是batch_size,x是標(biāo)簽數(shù)
target: (N, x), dtype: float,通常是標(biāo)簽的獨(dú)熱碼形式,注意需改成float格式
nn.BCEWithLogitsLoss()
相當(dāng)于BCE加上sigmoid
nn.BCEWithLogitsLoss()(input,target) == nn.BCELoss()(torch.sigmoid(input),target)focal_loss
focal loss在pytorch中沒有,它常用在目標(biāo)檢測(cè)問題中,公式和曲線見論文中的圖:
帶平衡參數(shù)的focal loss公式如下:
代碼:(待后補(bǔ))
heatmap_loss
heatmap_loss出現(xiàn)在anchor-free的目標(biāo)檢測(cè)網(wǎng)絡(luò)centernet和conernet中,它在focal loss的基礎(chǔ)上進(jìn)一步改進(jìn),加入了對(duì)熱點(diǎn)區(qū)域的損失減小的措施,以使模型輸出可以較容易的收斂到檢測(cè)點(diǎn)附件區(qū)域。(否則,必須收斂到檢測(cè)點(diǎn)的話,難度太大,收斂速度慢)
注意,它只是在otherwise情況下多加了一個(gè) (1?Yxyc)β(1-Y_{xyc})^\beta(1?Yxyc?)β 除此之外,就是focal loss
總結(jié)
以上是生活随笔為你收集整理的pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 好玩的deep dream(清晰版,py
- 下一篇: 用pytorch及numpy计算成对余弦