gather torch_浅谈Pytorch中的torch.gather函数的含义
pytorch中的gather函數
pytorch比tensorflow更加編程友好,所以準備用pytorch試著做最近要做的一些實驗。
立個flag開始學習pytorch,新開一個分類整理學習pytorch中的一些踩到的泥坑。
今天剛開始接觸,讀了一下documentation,寫一個一開始每太搞懂的函數gather
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
觀察它的輸出結果:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
這里是官方文檔的解釋
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) ? The source tensor
dim (int) ? The axis along which to index
index (LongTensor) ? The indices of elements to gather
out (Tensor, optional) ? Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
可以看出,gather的作用是這樣的,index實際上是索引,具體是行還是列的索引要看前面dim 的指定,比如對于我們的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是橫向,那么索引就是列號。index的大小就是輸出的大小,所以比如index是【1,0;0,0】,那么看index第一行,1列指的是2, 0列指的是1,同理,第二行為4,4 。這樣就輸入為【2,1;4,4】,參考這樣的解釋看上面的輸出結果,即可理解gather的含義。
gather在one-hot為輸出的多分類問題中,可以把最大值坐標作為index傳進去,然后提取到每一行的正確預測結果,這也是gather可能的一個作用。
以上這篇淺談Pytorch中的torch.gather函數的含義就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
總結
以上是生活随笔為你收集整理的gather torch_浅谈Pytorch中的torch.gather函数的含义的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: saiku 连接 MySQL_Saiku
- 下一篇: 华硕推出 Zenbook S 13 OL