pytorch中x.norm(p=2,dim=1,keepdim=True)的理解
代碼:x.norm(p=2,dim=1,keepdim=True)
功能:求指定維度上的范數(shù)。
函數(shù)原型:【返回輸入張量給定維dim?上每行的p范數(shù)】
? ? ? ? ? ? ? ? ?torch.norm(input, p, dim, out=None,keepdim=False) → Tensor
? ? ? ? ?注:范數(shù)求法:【對(duì)N個(gè)數(shù)據(jù)求p范數(shù)】
? ? ? ? ? ? ? ? ??
函數(shù)參數(shù):
input (Tensor) – 輸入張量
p (float) – 范數(shù)計(jì)算中的冪指數(shù)值
dim (int) – 縮減的維度,dim=0是對(duì)0維度上的一個(gè)向量求范數(shù),返回結(jié)果數(shù)量等于其列的個(gè)數(shù),也就是說(shuō)有多少個(gè)0維度的向? ? ? ? ? ? ? ? ? ? ? ? ? 量,?將得到多少個(gè)范數(shù)。dim=1同理。
out (Tensor, optional) – 結(jié)果張量
keepdim(bool)– 保持輸出的維度 。當(dāng)keepdim=False時(shí),輸出比輸入少一個(gè)維度(就是指定的dim求范數(shù)的維度)。而? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? keepdim=True時(shí),輸出與輸入維度相同,僅僅是輸出在求范數(shù)的維度上元素個(gè)數(shù)變?yōu)?。這也是為什么有時(shí)? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 我們把參數(shù)中的dim稱為縮減的維度,因?yàn)閚orm運(yùn)算之后,此維度或者消失或者元素個(gè)數(shù)變?yōu)?。
?
例子說(shuō)明:
已知一個(gè)3×4矩陣,如下:
tensor([[ 1.,? 2.,? 3.,? 4.],
??????? [ 2.,? 4.,? 6.,? 8.],
??????? [ 3.,? 6.,? 9., 12.]])
1)dim參數(shù),分別對(duì)其行和列分別求2范數(shù):
inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)
print(inputs1)
inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)
print(inputs2)
結(jié)果分別為:
tensor([[ 5.4772],
??????? [10.9545],
??????? [16.4317]])
tensor([[ 3.7417,? 7.4833, 11.2250, 14.9666]])
2)keepdim參數(shù)
inputs3 = inputs.norm(p=2, dim=1, keepdim=False)
print(inputs3)
inputs3為:
tensor([ 5.4772, 10.9545, 16.4317])
?
輸出inputs1和inputs3的shape:
print(inputs1.shape)
print(inputs3.shape)
torch.Size([3, 1])
torch.Size([3])
可以看到inputs3少了一維,其實(shí)就是dim=1(求范數(shù))那一維(列)少了,因?yàn)閺?列變成1列,就是3行中求每一行的2范數(shù),就剩1列了,不保持這一維不會(huì)對(duì)數(shù)據(jù)產(chǎn)生影響。或者也可以這么理解,就是數(shù)據(jù)每個(gè)數(shù)據(jù)有沒(méi)有用[]擴(kuò)起來(lái)。
即:
keepdim = True,用[]擴(kuò)起來(lái);
keepdim = False,不用[]括起來(lái);
?
【不寫(xiě)keepdim,則默認(rèn)不保留dim的那個(gè)維度】:
inputs4 = torch.norm(inputs, p=2, dim=1)
print(inputs4)
tensor([ 5.4772, 10.9545, 16.4317])
?
【不寫(xiě)dim,則計(jì)算Tensor中所有元素的2范數(shù)】:
inputs5 = torch.norm(inputs, p=2)
print(inputs5)
tensor(20.4939)
等價(jià)于這句話:
inputs6 = inputs.pow(2).sum().sqrt()
print(inputs6)
tensor(20.4939)
總之,norm操作后dim這一維變?yōu)?或者消失。
總結(jié)
以上是生活随笔為你收集整理的pytorch中x.norm(p=2,dim=1,keepdim=True)的理解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: css三个块元素重叠,CSS盒模型以及如
- 下一篇: 在c语言中我叫做符号变量,问渠网-C语言