一文搞懂F.binary_cross_entropy以及weight参数
相信有很多人在用pytorch做深度學(xué)習(xí)的時(shí)候,可能只是知道模型中用的是F.binary_cross_entropy或者F.cross_entropy,但是從來(lái)沒有想過(guò)這兩者的區(qū)別,即使知道這兩者是分別在什么情況下使用的,也沒有想過(guò)它們?cè)趐ytorch中是如何具體實(shí)現(xiàn)的。在另一篇文章中介紹了F.cross_entropy()的具體實(shí)現(xiàn),所以本文將介紹F.binary_cross_entropy的具體實(shí)現(xiàn)。
當(dāng)你分別了解了它們?cè)趐ytorch中的具體實(shí)現(xiàn),也就自然知道它們的區(qū)別以及應(yīng)用場(chǎng)景了。
1、pytorch對(duì)BCELoss的官方解釋
在自己實(shí)現(xiàn)F.binary_cross_entropy之前,我們首先得看一下pytorch的官方實(shí)現(xiàn),下面是pytorch官方對(duì)BCELoss類的描述:
在目標(biāo)和輸出之間創(chuàng)建一個(gè)衡量二進(jìn)制交叉熵的標(biāo)準(zhǔn)。the unreduced loss(如:reduction屬性被設(shè)置為none) 的數(shù)學(xué)表達(dá)式為:
ln=?wn[yn?log?xn+(1?yn)?log?(1?xn)]\quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] ln?=?wn?[yn??logxn?+(1?yn?)?log(1?xn?)]
其中,N表示batch size,如果reduction is not none(reduction的默認(rèn)是‘mean’)時(shí)的表達(dá)式為:
?(x,y)={mean?(L),if?reduction=’mean’;sum?(L),if?reduction=’sum’.\ell(x, y) = \begin{cases} \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} \end{cases} ?(x,y)={mean(L),sum(L),?if?reduction=’mean’;if?reduction=’sum’.?
補(bǔ)充:targets也就是表達(dá)式中的y應(yīng)該是0-1之間的數(shù),Xn不能為0或1,如果Xn是0或者1,也就意味著log(Xn)或者log(1-Xn)中的一項(xiàng)沒有意義,pytorch中對(duì)log(0)作出的定義如下,也是數(shù)學(xué)上對(duì)log(0)的定義:
log?(0)=?∞,lim?x→0log?(x)=?∞\log (0) = -\infty,\lim_{x\to 0} \log (x) = -\infty log(0)=?∞,x→0lim?log(x)=?∞
然而,由于一些原因,無(wú)窮項(xiàng)在在損失函數(shù)中無(wú)法表述。舉個(gè)例子:如果Yn=0或者1-Yn=0,我們就會(huì)用0乘上無(wú)窮。而且如果我們有一個(gè)無(wú)窮的損失值,我們?cè)谟?jì)算梯度的時(shí)候也會(huì)是一個(gè)無(wú)窮,也是因?yàn)閿?shù)學(xué)上的定義:
lim?x→0ddxlog?(x)=∞\lim_{x\to 0} \fracze8trgl8bvbq{dx} \log (x) = \infty x→0lim?dxd?log(x)=∞
而且會(huì)導(dǎo)致BECLoss的反向傳播方法非線性。對(duì)于上述可能會(huì)出現(xiàn)的問(wèn)題,pytorch官方給出的解決方案是限制log函數(shù)的輸出大于等于-100,這樣的話就可以得到一個(gè)有限的損失值,以及線性的反向傳播方法。下面寫個(gè)代碼測(cè)試一下pytorch限制log函數(shù)輸出的機(jī)制:
首先,我們?nèi)∫粋€(gè)數(shù)讓其log運(yùn)算后的值小于-100,發(fā)現(xiàn)F.binary_cross_entropy中的計(jì)算結(jié)果為100,而torch.log()的計(jì)算結(jié)果為負(fù)無(wú)窮,原因在于pytorch官方實(shí)現(xiàn)的F.binary_cross_entropy對(duì)log輸出做了限制。大家不要對(duì)100感到疑惑呀,為什么不是-100,那是因?yàn)閾p失函數(shù)計(jì)算的時(shí)候前面有個(gè)負(fù)號(hào)。
2、pytorch的官方實(shí)現(xiàn)
input的維度(N,*),其中*表示可以是任何維度。target和input的維度需一致。OK,其實(shí)最關(guān)鍵的還是上面的數(shù)學(xué)表達(dá)式,知道了表達(dá)式也就可以簡(jiǎn)單實(shí)現(xiàn)二值交叉熵了。
輸出:
# input tensor([[[0.7266, 0.9478, 0.3987],[0.4134, 0.1654, 0.0298],[0.1266, 0.1153, 0.0549]]]) # target tensor([[[0., 1., 1.],[1., 0., 0.],[0., 0., 0.]]]) # output tensor(0.6877)3、根據(jù)公式自己實(shí)現(xiàn)
class binary_ce_loss(torch.nn.Module):def __init__(self):super(binary_ce_loss, self).__init__()def forward(self, input, target):input = input.view(input.shape[0], -1)target = target.view(target.shape[0], -1)loss = 0.0for i in range(input.shape[0]):for j in range(input.shape[1]):loss += -(target[i][j] * torch.log(input[i][j]) + (1 - target[i][j]) * torch.log(1 - input[i][j]))return loss/(input.shape[0]*input.shape[1]) # 默認(rèn)取均值input和target的維度需相同,上述的例子中,它們的維度均是[1,3,3],我們可以把1看作batchsize的大小,3*3看作是圖片的大小。首先將shape變成[1,3*3],然后按照公式計(jì)算每一個(gè)batchsize的損失,再求和,最后按照pytorch官方默認(rèn)的方式求平均,即可大功告成。
4、weight參數(shù)含義
在寫代碼的過(guò)程中,我們會(huì)發(fā)現(xiàn)F.binary_cross_entropy中還有一個(gè)參數(shù)weight,它的默認(rèn)值是None,估計(jì)很多人不知道weight參數(shù)怎么作用的,下面簡(jiǎn)單的分析一下:
首先,看一下pytorch官方對(duì)weight給出的解釋,if provided it’s repeated to match input tensor shape,就是給出weight參數(shù)后,會(huì)將其shape和input的shape相匹配。回憶公式:
ln=?wn[yn?log?xn+(1?yn)?log?(1?xn)]\quad l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] ln?=?wn?[yn??logxn?+(1?yn?)?log(1?xn?)]
默認(rèn)情況,也就是weight=None時(shí),上述公式中的Wn=1;當(dāng)weight!=None時(shí),也就意味著我們需要為每一個(gè)樣本賦予權(quán)重Wi,這樣weight的shape和input一致就很好理解了。
首先看pytorch中weight參數(shù)作用后的結(jié)果:
通過(guò)下面的代碼再次驗(yàn)證weight是如何作用的,weight就是為每一個(gè)樣本加權(quán)。
class binary_ce_loss(torch.nn.Module):def __init__(self):super(binary_ce_loss, self).__init__()def forward(self, input, target, weight=None):input = input.view(input.shape[0], -1)target = target.view(target.shape[0], -1)loss = 0.0for i in range(input.shape[0]):for j in range(input.shape[1]):loss += -weight[i][j] * (target[i][j] * torch.log(input[i][j]) + (1 - target[i][j]) * torch.log(1 - input[i][j]))return loss/(input.shape[0]*input.shape[1]) # 默認(rèn)取均值 myloss = binary_ce_loss() print(myloss(input, target, weight=weight)) """ # myloss tensor(0.4621) """pytorch官方的代碼和自己實(shí)現(xiàn)的計(jì)算出的損失一致,再次說(shuō)明binary_cross_entropy的weight權(quán)重會(huì)分別對(duì)應(yīng)的作用在每一個(gè)樣本上。
5、總結(jié)
看源碼是最直接有效的手段。 留個(gè)彩蛋,下篇文章講balanced_cross_entropy,解決樣本之間的不平衡問(wèn)題。
注:如有錯(cuò)誤還請(qǐng)指出!
總結(jié)
以上是生活随笔為你收集整理的一文搞懂F.binary_cross_entropy以及weight参数的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: QQ等级计算方法及图标
- 下一篇: 初识手机阅读行业