PyTorch 笔记(06)— Tensor 索引操作(index_select、masked_select、non_zero、gather)
Tensor 支持與 numpy.ndarray 類似的索引操作,如無特殊說明,索引出來的結(jié)果與源 tensor 共享內(nèi)存,即修改一個,另外一個也會跟著改變。
In [65]: a = t.arange(0,6).reshape(2,3) In [66]: a
Out[66]:
tensor([[0, 1, 2],[3, 4, 5]])
1. 初級索引
1. 獲取第 0 行
In [67]: a[0] # 第 0 行
Out[67]: tensor([0, 1, 2])
2. 獲取第 0 列
In [68]: a[:,0] # 第 0 列
Out[68]: tensor([0, 3])
3. 獲取第 0 行某個元素
In [69]: a[0][2] # 第 0 行 第 2 個元素
Out[69]: tensor(2)In [70]: a[0,2] # 等價 a[0][2]
Out[70]: tensor(2)In [71]: a[0, -1] # 第 0 行 最后一個元素
Out[71]: tensor(2)
4. 獲取前 1 行
In [72]: a[:1] # 前 1 行
Out[72]: tensor([[0, 1, 2]])In [73]: a[0:1, 0:2] # 第 0 行第 0 列 和第 0 行第 1 列
Out[73]: tensor([[0, 1]])In [74]: a[0:2, 1:2] # 第 0 行第 1 列 和第 1 行第 1 列
Out[74]:
tensor([[1],[4]])In [75]: a[0:2, 0:2]
Out[75]:
tensor([[0, 1],[3, 4]])In [76]:
2. 高級索引
常用選擇函數(shù)如下表所示:
2.1 index_select
index_select(input, dim, index)
input表示輸入的變量;dim表示從第幾維挑選數(shù)據(jù),類型為int值;index表示從選擇維度中的哪個位置挑選數(shù)據(jù),類型為torch.Tensor類的實例;
t.index_select(a, 0, t.tensor([0, 1])) 表示挑選第 0 維,t.tensor([0, 1]) 表示第 0 行、第 1 行
t.index_select(a, 1, t.tensor([1, 3])) 表示挑選第 1 維,t.tensor([1, 3]) 表示第 1 行、第 3 行(第一行從 0 開始計數(shù))
In [9]: a = t.arange(0, 12).reshape(3,4) In [10]: a
Out[10]:
tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])In [11]: b = t.index_select(a, 0, t.tensor([0, 1])) In [12]: b
Out[12]:
tensor([[0, 1, 2, 3],[4, 5, 6, 7]])In [13]: a.index_select(0, t.tensor([0, 1]))
Out[13]:
tensor([[0, 1, 2, 3],[4, 5, 6, 7]])In [17]: a.index_select(1, t.tensor([1,3]))
Out[17]:
tensor([[ 1, 3],[ 5, 7],[ 9, 11]])In [18]: c = t.index_select(a, 1, t.tensor([1, 3])) In [19]: c
Out[19]:
tensor([[ 1, 3],[ 5, 7],[ 9, 11]])In [20]:
2.2 masked_select
torch.masked_select(input, mask, out=None)
根據(jù)掩碼張量 mask 中的二元值,取輸入張量中的指定項,將取值返回到一個新的 1D 張量。張量 mask 須跟 input 張量有相同的元素數(shù)目,但形狀或維度不需要相同。返回的張量不與原始張量共享內(nèi)存空間。
input(Tensor)輸入張量;mask(ByteTensor)掩碼張量,包含了二元索引值;out目標張量;
In [1]: import torch as tIn [2]: a = t.arange(0, 6).reshape(2, 3)In [76]: a
Out[76]:
tensor([[0, 1, 2],[3, 4, 5]])In [77]: a > 2
Out[77]:
tensor([[False, False, False],[ True, True, True]])In [78]: a[a>2] # 選擇結(jié)果與源 Tensor 不共享內(nèi)存空間
Out[78]: tensor([3, 4, 5])In [79]: a.masked_select(a>2) # 等價于 a[a>2]
Out[79]: tensor([3, 4, 5])In [80]: a[a>2][0] = 100 In [81]: a
Out[81]:
tensor([[0, 1, 2],[3, 4, 5]])In [82]: a[a>2]
Out[82]: tensor([3, 4, 5])In [83]:
2.3 non_zero
non_zero 返回一個包含輸入 input 中非零元素索引的張量。輸出張量中的每行包含 input 中非零元素的索引。
如果輸入 input 有 n 維,則輸出的索引張量 out 的 size 為 z x n , 這里 z 是輸入張量 input 中所有非零元素的個數(shù)。
In [2]: a = t.arange(0, 6).reshape(2, 3)In [3]: a
Out[3]:
tensor([[0, 1, 2],[3, 4, 5]])In [4]: type(a>2)
Out[4]: torch.TensorIn [5]: a.nonzero()
Out[5]:
tensor([[0, 1],[0, 2],[1, 0],[1, 1],[1, 2]])In [7]: t.nonzero(a!=0)
Out[7]:
tensor([[0, 1],[0, 2],[1, 0],[1, 1],[1, 2]])In [8]:
2.4 gather
收集輸入的特定維度指定位置的數(shù)值。
torch.gather(input, dim, index, out=None) → Tensor
input (Tensor)– 源張量,也就是輸入的待處理變量;dim (int)– 索引的軸,待操作的維度;index (LongTensor)– 聚合元素的下標out (Tensor, optional)– 目標張量
In [1]: import torch as tIn [8]: a = t.arange(0, 6).reshape(2, 3)In [9]: a
Out[9]:
tensor([[0, 1, 2],[3, 4, 5]])In [4]: a.sum(dim=0)
Out[4]: tensor([3, 5, 7])In [5]: a.sum(dim=1)
Out[5]: tensor([ 3, 12])In [12]: t.gather(a, 0, t.LongTensor([[0,1,0], [1,0,0]]))
Out[12]:
tensor([[0, 4, 2],[3, 1, 2]])In [13]: t.gather(a, 1, t.LongTensor([[2,0,1], [1,2,0]]))
Out[13]:
tensor([[2, 0, 1],[4, 5, 3]])In [14]:
由 a.sum(dim=0) 可知當 dim=0 時,是按照列的方向求和的,所以求 t.gather(a, 0, t.LongTensor([[0,1,0], [1,0,0]])) 值時可以按照以下步驟進行:
- 取各個元素的列下標,如
[(x,0), (x,1), (x,2)], [(x,0), (x,1), (x,2)] - 取
t.LongTensor([[0,1,0], [1,0,0]])值作為行下標, 如[(0,0), (1,1), (0,2)], [(1,0), (0,1), (0,2)] - 根據(jù)步驟 2 得到的索引在 input 中求值,即
a[0][0] = 0,a[1][1] = 4,a[0][2] = 2
a[1][0] = 3,a[0][1] = 1,a[0][2] = 2
得到如下值
tensor([[0, 4, 2],[3, 1, 2]])
同理,對于 a.sum(dim=1) 可知當 dim=1 時,是按照行的方向求和的,所以求 t.LongTensor([[2,0,1], [1,2,0]]) 值時可以按照以下步驟進行:
- 取各個元素的行下標,如
[(0,x), (0,x), (0,x)], [(1,x), (1,x), (1,x)] - 取
t.LongTensor([[2,0,1], [1,2,0]])值作為列下標, 如[(0,2), (0,0), (0,1)], [(1,1), (1,2), (1,0)] - 根據(jù)步驟 2 得到的索引在 input 中求值,即
a[0][2] = 2,a[0][0] = 0,a[0][1] = 1
a[1][1] = 4,a[1][2] = 5,a[1][0] = 3
得到如下值
tensor([[2, 0, 1],[4, 5, 3]])
總結(jié)
以上是生活随笔為你收集整理的PyTorch 笔记(06)— Tensor 索引操作(index_select、masked_select、non_zero、gather)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: “慵坐但含情”下一句是什么
- 下一篇: 求一个一个人伤感的微信网名