pytorch学习——torch.cat和torch.stack的区别
合并tensors
- torch.cat 沿著特定維數連接一系列張量。
- torch.stack 沿新維度連接一系列張量。
torch.cat
在給定維度中連接給定的 seq 個張量序列。
所有張量必須具有相同的形狀(連接維度除外)或為空。
torch.cat(tensors, dim=0, *, out=None) → Tensor
參數
- tensors(張量序列):任何相同類型的張量序列。 提供的非空張量必須具有相同的形狀。在給定維度上對輸入的張量序列進行連接操作。
- dim (int) : 張量連接的維度,
torch.stack
沿新維度連接一系列張量。(維度疊加)
所有張量都需要具有相同的大小。
torch.stack(tensors, dim=0, *, out=None) → Tensor
參數
- tensors(張量序列):要連接的張量序列
- dim (int) : 要插入的維度。必須介于 0 和串聯張量的維數之間(含)
示例
沿第0維操作:
import torchx1 = torch.tensor([[1,2,3], [4,5,6]])# x1.shape = tensor.size([2,3])
x2 = torch.tensor([[7,8,9], [10,11,12]])# x2.shape = tensor.size([2,3])
print(x1.shape)
print('沿第0維進行操作:')
y1 = torch.cat([x1, x2], dim=0)
y2 = torch.stack([x1, x2], dim=0)
print('cat, y1:', y1.shape,'\n',y1)
print('stack, y2:', y2.shape,'\n',y2)
輸出:
沿第0維進行操作:
cat, y1: torch.Size([4, 3]) tensor([[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9],[10, 11, 12]])
stack, y2: torch.Size([2, 2, 3]) tensor([[[ 1, 2, 3],[ 4, 5, 6]],[[ 7, 8, 9],[10, 11, 12]]])
從y1的輸出可以看到,cat在第0維將x1和x2元素進行續接,即輸出為[x1[0], x1[1], x2[0], x2[1]], shape由[2, 3]變為[4,3]。
從y2的輸出可以看到,stack直接將x1和x2的第0維進行疊加,即輸出為[x1, x2],shape由[2,3]變為[2, 2, 3]。
沿第1維操作:
print('沿第1維進行操作:')
y1 = torch.cat(x, dim=1)
y2 = torch.stack(x, dim=1)
print('cat, y1:', y1.shape,'\n',y1)
print('stack, y2:', y2.shape,'\n',y2)
輸出:
沿第1維進行操作:
cat, y1: torch.Size([2, 6]) tensor([[ 1, 2, 3, 7, 8, 9],[ 4, 5, 6, 10, 11, 12]])
stack, y2: torch.Size([2, 2, 3]) tensor([[[ 1, 2, 3],[ 7, 8, 9]],[[ 4, 5, 6],[10, 11, 12]]])
從y1的輸出可以看到,cat將x1和x2相對應的第1維的元素進行續接, shape由[2,3]變為[2, 6]。
從y2的輸出可以看到,stack直接將x1和x2相對應的第1維的元素進行疊加,即輸出為[[x1[0], x2[0]], [x1[1], x2[1]],shape由[2,3]變為[2, 2, 3]。
沿第2維操作:
輸出
y1 = torch.cat(x, dim=2)
print('cat, y1:', y1.shape,'\n',y1)Traceback (most recent call last):File "/Users/gyuer/Desktop/test.py", line 8, in <module>y1 = torch.cat(x, dim=2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
y2 = torch.stack(x, dim=2)
print('stack, y2:', y2.shape,'\n',y2)stack, y2: torch.Size([2, 3, 2]) tensor([[[ 1, 7],[ 2, 8],[ 3, 9]],[[ 4, 10],[ 5, 11],[ 6, 12]]])
從以上結果可以看出,torch.stack(x, dim=2)是將x1[i][j]和x2[i][j]堆疊在一起的。如x1[0][0]=1和x2[0][0]=7堆疊在一起,得到[1, 7]。
stack的參數dim要插入的維度必須介于 0 和串聯張量的維數之間
以上總結借鑒了官網的英文解釋和https://blog.csdn.net/weixin_42920104/article/details/105833691
總結
以上是生活随笔為你收集整理的pytorch学习——torch.cat和torch.stack的区别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 歌词花开花谢是哪首歌啊?
- 下一篇: 别克lacrosse2.8t后保险杠杠多