Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill
文章目錄
- 1、簡介
- 2、torch.mm
- 3、torch.bmm
- 4、torch.matmul
- 5、masked_fill
1、簡介
這幾天正在看NLP中的注意力機制,代碼中涉及到了一些關于張量矩陣乘法和填充一些代碼,這里積累一下。主要參考了pytorch2.0的官方文檔。
①torch.mm(input,mat2,*,out=None)
②torch.bmm(input,mat2,*,out=None)
③torch.matmul(input, other, *, out=None)
④Tensor.masked_fill
2、torch.mm
torch.mm語法為:
torch.mm(input, mat2, *, out=None) → Tensor就是矩陣的乘法。如果輸入input是(n,m),mat2是(m, p),則輸出為(n, p)。
示例:
3、torch.bmm
torch.bmm語法為:
torch.bmm(input, mat2, *, out=None) → Tensor- 功能:對存儲在input和mat2矩陣中的批數量的矩陣進行乘積。
- 要求:input矩陣和mat2必須是三維的張量,且第一個維度即batch維度必須一樣。
- 舉例:如果input是一個(b, n , m)的張量,mat2是一個(b, m, p)張量,則輸出形狀為(b, n, p)
示例:
input = torch.randn(10, 3, 4) mat2 = torch.randn(10, 4, 5) res = torch.bmm(input, mat2) res.size() -->torch.Size([10, 3, 5])解讀:實際上刻畫的就是一組矩陣與另一組張量矩陣的乘積,至于一組有多少個矩陣,由input和mat2的第一個輸入維度決定,上述代碼第一個維度為10,就代表著10個形狀為(3, 4)的矩陣與10個形狀為(4, 5)的矩陣分別對應相乘,得到10個形狀為(3, 5)的矩陣。
4、torch.matmul
torch.matmul語法為:
torch.matmul(input, other, *, out=None) → Tensor該函數刻畫的是兩個張量的乘積,且計算過程與張量的維度密切相關。
① 如果張量是一維的,輸出結果是點乘,是一個標量。
a = torch.tensor([1,2,4]) b = torch.tensor([2,5,6]) print(torch.matmul(a, b)) print(a.shape) --> tensor(36) -->torch.Size([3])注意:張量a.shape顯示的是torch.Size([3]),只有一個維度,3是指這個維度中有3個數。
② 如果兩個張量都是二維的,執行的是矩陣的乘法。
由上述示例可知,如果兩個張量均為2維,那么其運算和torch.mm是一樣的。
③如果第一個參數input是1維的,第二個參數是二維的,那么在計算時,在第一個參數前增加一個維度1,計算完畢之后再把這個維度去掉。
如上所示,a只有一個維度,在進行計算時,變成了(1, 3),則變成了(1, 3)乘以(3, 2),變成(1, 2),最后在去掉1這個維度。
④如果第一個參數是2維的,第二個參數是1維的,則返回矩陣-向量乘積。
矩陣乘以張量,就是矩陣中的每一行都與這個張量相乘,最終得到一個一維的,大小為3的結果。
⑤多個維度
- 如果兩個參數至少都是1維的,且有一個參數的維度N>2,則返回的是一個批矩陣的乘積(即把多出的那個維度看作batch即可,讓每個batch后的矩陣與后面的張量相乘即可)。
- 如果第一個參數是1維的,則在它的維度前加上1,以便批量矩陣相乘并在之后刪除。如果第二個參數是1維的,則將1追加到其維度,用于批處理矩陣倍數,然后刪除。
- 舉例:如果input形狀是(j,1,n,n),other的張量形狀是(k,n,n),那么輸出張量的形狀將會是(j,k,n,n)。
- 如果input形狀是(j,1,n,m),other的張量形狀是(k,m,p),那么輸出張量的形狀將會是(j,k,n,p)。
仔細比較上述三個代碼塊,其最終的結果是一樣的。可以簡單記為如果兩個維度不一致的話,多出的維度就看作是batch維,相當于在低維度前面增加一個維度。
5、masked_fill
語法為:
Tensor.masked_fill_(mask, value)參數:
- mask(BoolTensor):布爾掩碼
- value(float):用于填充的值。
mask是一個pytorch張量,元素是布爾值,value是要填充的值,填充規則是mask中取值為True的位置對應與需要填充的張量中的位置用value填充。
a = torch.tensor([[0, 8],[ 6, 8],[ 7, 1] ])mask = torch.tensor([[ True, False],[False, False],[False, True] ]) b = a.masked_fill(mask, -1e9) print(b) -->tensor([[-1000000000, 8],[ 6, 8],[ 7, -1000000000]])總結
以上是生活随笔為你收集整理的Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 国庆节堕落的日子
- 下一篇: 精品基于Uniapp+SSM实现的公园植