tf.gather( )的详细解析
生活随笔
收集整理的這篇文章主要介紹了
tf.gather( )的详细解析
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
tf.gather()函數
tf.gather() 該接口的作用:就是抽取出params的第axis維度上在indices里面所有的indextf.gather(params,indices,validate_indices=None,name=None,axis=0 )''' Args:params: A Tensor. The tensor from which to gather values. Must be at least rank axis + 1.indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Must be in range [0, params.shape[axis]).axis: A Tensor. Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to the first dimension. Supports negative indexes.name: A name for the operation (optional). Returns:A Tensor. Has the same type as params. '''參數說明:
params: A Tensor.
indices: A Tensor. types必須是: int32, int64. 里面的每一個元素大小必須在 [0, params.shape[axis])范圍內.
axis: 維度。沿著params的哪一個維度進行抽取indices
返回的是一個tensor
示例
1.當axis參數省略時,即axis默認為axis=0
temp=np.arange(7)*7*2+tf.constant(1,shape=[7]) print(temp) #temp是一個tf張量,直接打印只顯示形狀和類型,需要創建會話機制temp1=tf.gather(temp,[0,2,4]) #取temp一維張量中第0,第2,第4的值形成一個tf張量with tf.Session() as sess:print(sess.run(temp))print(sess.run(temp1))結果:
Tensor("add:0", shape=(7,), dtype=int32) [ 1 15 29 43 57 71 85] [ 1 29 57]2.當temp為多維時,這里是4維張量,indices=[0,2],axis=0
input =[ [[[1, 1, 1], [2, 2, 2]],[[3, 3, 3], [4, 4, 4]],[[5, 5, 5], [6, 6, 6]]],[[[7, 7, 7], [8, 8, 8]],[[9, 9, 9], [10, 10, 10]],[[11, 11, 11], [12, 12, 12]]],[[[13, 13, 13], [14, 14, 14]],[[15, 15, 15], [16, 16, 16]],[[17, 17, 17], [18, 18, 18]]]]print(tf.shape(input)) with tf.Session() as sess:output=tf.gather(input, [0,2],axis=0)#其實默認axis=0print(sess.run(output))結果:
Tensor("Shape:0", shape=(4,), dtype=int32) [[[[ 1 1 1][ 2 2 2]][[ 3 3 3][ 4 4 4]][[ 5 5 5][ 6 6 6]]][[[13 13 13][14 14 14]][[15 15 15][16 16 16]][[17 17 17][18 18 18]]]]解釋:
第一個[ 是列表語法需要的括號,剩下的最里面的三個[[[是axis=0需要搜尋的中括號。這里一共有3個[[[。
indices的[0,2]即取第0個[[[和第2個[[[,也就是第0個和第2個三維立體。
3.當indices=[0,2],axis=1
input =[ [[[1, 1, 1], [2, 2, 2]],[[3, 3, 3], [4, 4, 4]],[[5, 5, 5], [6, 6, 6]]],[[[7, 7, 7], [8, 8, 8]],[[9, 9, 9], [10, 10, 10]],[[11, 11, 11], [12, 12, 12]]],[[[13, 13, 13], [14, 14, 14]],[[15, 15, 15], [16, 16, 16]],[[17, 17, 17], [18, 18, 18]]]] print(tf.shape(input)) with tf.Session() as sess:output=tf.gather(input, [0,2],axis=1)#默認axis=0print(sess.run(output))結果:
Tensor("Shape:0", shape=(4,), dtype=int32) [[[[ 1 1 1][ 2 2 2]][[ 5 5 5][ 6 6 6]]][[[ 7 7 7][ 8 8 8]][[11 11 11][12 12 12]]][[[13 13 13][14 14 14]][[17 17 17][18 18 18]]]]解釋:
第一個[ 是列表語法需要的括號,先把這個干擾去掉,剩下的每個[[[中所有內側的 [[ 是axis=1搜索的中括號。
然后[0,2]即再取每個[[[體內的第0個[[和第2個[[,也就是去每個三維體的第0個面和第2個面
總結
以上是生活随笔為你收集整理的tf.gather( )的详细解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Packet Tracer 思科模拟器之
- 下一篇: tf.boolean_mask()的详细