NLLLoss CrossEntropyLoss Pytorch
NLLLoss
在圖片單標簽分類時,輸入m張圖片,輸出一個m*N的Tensor,其中N是分類個數。比如輸入3張圖片,分三類,最后的輸出是一個3*3的Tensor,舉個例子:
第123行分別是第123張圖片的結果,假設第123列分別是貓、狗和豬的分類得分。
可以看出模型認為第123張都更可能是貓。
然后對每一行使用Softmax,這樣可以得到每張圖片的概率分布。
這里dim的意思是計算Softmax的維度,這里設置dim=1,可以看到每一行的加和為1。比如第一行0.6600+0.0570+0.2830=1。
如果設置dim=0,就是一列的和為1。比如第一列0.2212+0.3050+0.4738=1。
我們這里一張圖片是一行,所以dim應該設置為1。
然后對Softmax的結果取自然對數:
Softmax后的數值都在0~1之間,所以ln之后值域是負無窮到0。
NLLLoss的結果就是把上面的輸出與Label對應的那個值拿出來,再去掉負號,再求均值。
假設我們現在Target是[0,2,1](第一張圖片是貓,第二張是豬,第三張是狗)。第一行取第0個元素,第二行取第2個,第三行取第1個,去掉負號,結果是:[0.4155,1.0945,1.5285]。再求個均值,結果是:
下面使用NLLLoss函數驗證一下:
嘻嘻,果然是1.0128!
CrossEntropyLoss
CrossEntropyLoss就是把以上Softmax–Log–NLLLoss合并成一步,我們用剛剛隨機出來的input直接驗證一下結果是不是1.0128:
總結
以上是生活随笔為你收集整理的NLLLoss CrossEntropyLoss Pytorch的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: detach detach_ pyto
- 下一篇: array.array python y