Tensorflow tf.gather 函数中的 axis 参数是什么意思?

6
我将使用TensorFlow的tf.gather从一个多维数组中获取元素,例子如下:
import tensorflow as tf

indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

result = tf.gather(x, indices, axis=1)

with tf.Session() as sess:
    selection = sess.run(result)
    print(selection)

这将导致:
[[1 2 2]
 [4 5 5]
 [7 8 8]]

不过我想要的是:

[1
 5
 8]

如何使用tf.gather来在指定的轴上应用单一索引? (与此回答中指定的解决方法产生相同的结果:https://dev59.com/wVoV5IYBdhLWcg3wCrLH#41845855

1个回答

2
你需要将indices转换为full indices,并使用gather_nd函数。可以通过以下方式实现:
result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接