pytorch中数组维度的理解
生活随笔
收集整理的這篇文章主要介紹了
pytorch中数组维度的理解
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
pytorch中數組維度理解與numpy中類似,pytorch中維度用dim表示,numpy中用axis表示
這里主要想說下維度的變化。
dim = x ,表示在第x為上進行操作,那個維度會發生變化。
一、二維數組
1. 兩個二維數組的拼接
維度為(2,3)與(2,4)的數組拼接后的維度是(2,7)
import torch a = torch.Tensor(np.arange(6).reshape(2,3)) b = torch.Tensor(np.arange(8).reshape(2,4)) print(a,'\n ',a.shape) print(b,'\n',b.shape) c = torch.cat((a,b),dim = 1) print('concatenate:\n',c,'\n',c.shape)結果
tensor([[0., 1., 2.],[3., 4., 5.]]) a: torch.Size([2, 3]) tensor([[0., 1., 2., 3.],[4., 5., 6., 7.]]) torch.Size([2, 4]) concatenate:tensor([[0., 1., 2., 0., 1., 2., 3.],[3., 4., 5., 4., 5., 6., 7.]]) torch.Size([2, 7])2. 二維數組求sum、max等
dim = 0,第一個維度劃掉,得到一個一維向量。比如,a是(2,3),dim = 0,得到的結果是(3,)維的;如果dim=1,得到的結果是(2,)
print('sum dim=0',torch.sum(a,dim=0)) print('sum dim=1',torch.sum(a,dim=1)) print('******* max *****') print('max dim=0',torch.max(a,dim=0)) print('max dim=1',torch.max(a,dim=1))輸出
tensor([[0., 1., 2.],[3., 4., 5.]]) torch.Size([2, 3]) sum dim=0 tensor([3., 5., 7.]) sum dim=1 tensor([ 3., 12.]) ******* max ***** max dim=0 torch.return_types.max( values=tensor([3., 4., 5.]), indices=tensor([1, 1, 1])) max dim=1 torch.return_types.max( values=tensor([2., 5.]), indices=tensor([2, 2]))二、三維數組
1. 兩個三維數組的拼接
兩個三位數組拼接,有個要求,除了dim維,其余維的維度要相同。
- 比如 a是(2,3,4),b是(3,2,4)那么a與b無論在哪個維上都不能拼接。因為它們沒有兩個相同的維度。
- 如果a與b維度相同,都是(2,3,4),那么他們無論在哪個維上都可以拼接。dim = 0,結果是(4,3,4),dim = 1,結果是(2,6,4),dim =2,結果是(2,3,8)
- dim = x,就將兩個數組dim維上的數字相加,得到最終輸出維度。
輸出結果
tensor([[[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.]],[[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]]) torch.Size([2, 3, 4]) tensor([[[24., 25., 26., 27.],[28., 29., 30., 31.],[32., 33., 34., 35.]],[[36., 37., 38., 39.],[40., 41., 42., 43.],[44., 45., 46., 47.]]]) torch.Size([2, 3, 4]) concatenate:tensor([[[ 0., 1., 2., 3., 24., 25., 26., 27.],[ 4., 5., 6., 7., 28., 29., 30., 31.],[ 8., 9., 10., 11., 32., 33., 34., 35.]],[[12., 13., 14., 15., 36., 37., 38., 39.],[16., 17., 18., 19., 40., 41., 42., 43.],[20., 21., 22., 23., 44., 45., 46., 47.]]]) torch.Size([2, 3, 8])2. 三維數組求sum、max等
- 類似于二維數組,會消去dim維度
- shape=(2,3,4)的數組,在dim=0上求和或者取最大后,結果的shape = (3,4)
- pytorch求max,同時返回兩個值(max,indices)
結果
tensor([[[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.]],[[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]]) torch.Size([2, 3, 4]) sum dim=0 tensor([[12., 14., 16., 18.],[20., 22., 24., 26.],[28., 30., 32., 34.]]) sum dim=1 tensor([[12., 15., 18., 21.],[48., 51., 54., 57.]]) sum dim=2 tensor([[ 6., 22., 38.],[54., 70., 86.]]) ******* max ***** max dim=0 torch.return_types.max( values=tensor([[12., 13., 14., 15.],[16., 17., 18., 19.],[20., 21., 22., 23.]]), indices=tensor([[1, 1, 1, 1],[1, 1, 1, 1],[1, 1, 1, 1]])) max dim=1 torch.return_types.max( values=tensor([[ 8., 9., 10., 11.],[20., 21., 22., 23.]]), indices=tensor([[2, 2, 2, 2],[2, 2, 2, 2]])) max dim=2 torch.return_types.max( values=tensor([[ 3., 7., 11.],[15., 19., 23.]]), indices=tensor([[3, 3, 3],[3, 3, 3]]))總結
以上是生活随笔為你收集整理的pytorch中数组维度的理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: numpy中的加法
- 下一篇: permute、transpose、vi