『PyTorch』矩阵乘法总结
1. 二維矩陣乘法 torch.mm()
torch.mm(mat1, mat2, out=None),其中mat1((n imes m)),mat2((m imes d)),輸出out的維度是((n imes d))。
該函數一般只用來計算兩個二維矩陣的矩陣乘法,并且不支持broadcast操作。
2. 三維帶batch的矩陣乘法 torch.bmm()
由于神經網絡訓練一般采用mini-batch,經常輸入的時三維帶batch的矩陣,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1((b imes n imes m)),bmat2((b imes m imes d)),輸出out的維度是((b imes n imes d))。
該函數的兩個輸入必須是三維矩陣且第一維相同(表示Batch維度),不支持broadcast操作。
3. 多維矩陣乘法 torch.matmul()
torch.matmul(input, other, out=None)支持broadcast操作,使用起來比較復雜。
針對多維數據 matmul()乘法,我們可以認為該matmul()乘法使用使用兩個參數的后兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是input((1000 imes 500 imes 99 imes 11)), other((500 imes 11 imes 99))那么我們可以認為torch.matmul(input, other, out=None)乘法首先是進行后兩位矩陣乘法得到((99 imes 11) imes (11 imes 99)Rightarrow(99 imes 99)) ,然后分析兩個參數的batch size分別是 (( 1000 imes 500)) 和 (500) , 可以廣播成為 ((1000 imes 500)), 因此最終輸出的維度是((1000 imes 500 imes 99 imes 99))。
4. 矩陣逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None),其中other乘數可以是標量,也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可
5. 兩個運算符 @ 和 *
@:矩陣乘法,自動執行適合的矩陣乘法函數
*:element-wise乘法
總結
以上是生活随笔為你收集整理的『PyTorch』矩阵乘法总结的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 怎么跳过Windows 10/8/7登录
- 下一篇: python plt_python的pl