pytorch 之 有关交叉熵函数使用的几点说明
生活随笔
收集整理的這篇文章主要介紹了
pytorch 之 有关交叉熵函数使用的几点说明
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1.函數原型:loss_func = nn.CrossEntropyLoss()
? ? ? ? ? ? ? ? ? ? ?loss = loss_func(pre_label, label)
2.值得注意的點,這里的label不需要賦值one-hot編碼類型,因為函數內部會自動將label變換為one-hot類型,如果這里賦值為one-hot編碼,則會產生類似如下報錯:
①:RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1549635019666/work/aten/src
解決辦法:使用數值標簽,而非one-hot編碼
②:expected type torch.cuda.DoubleTensor but got torch.cuda.FloatTensor
解決辦法:針對這樣的錯誤,我們之前提到過,基本原因就是函數所需要的參數類型和我們賦值類型不同,這里介紹一種改變torch中tensor的類型的函數:
data = data.type(torch.FloatTensor)
data = data.type(torch.LongTensor)
data = data.type(torch.FloatTensor)
總結
以上是生活随笔為你收集整理的pytorch 之 有关交叉熵函数使用的几点说明的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 记录 之 Argparse 中的 可选参
- 下一篇: pytorch 之 torch.max(