TensorFlow:沿轴的张量最大值

37

我的问题由两个相关部分组成:

  1. 如何计算张量沿着特定轴的最大值?例如,如果我有以下张量:

    x = tf.constant([[1,220,55],[4,3,-1]])
    

    我希望你能提供类似的东西

    x_max = tf.max(x, axis=1)
    print sess.run(x_max)
    
    output: [220,4]
    

    我知道有tf.argmaxtf.maximum,但是它们都不能返回单个张量沿着某个轴的最大值。目前我只能用一个变通方法:

    x_max = tf.slice(x, begin=[0,0], size=[-1,1])
    for a in range(1,2):
        x_max = tf.maximum(x_max , tf.slice(x, begin=[0,a], size=[-1,1]))
    

    但它看起来不太理想。有更好的方法吗?

  2. 假设有一个张量的argmax的索引,如何使用这些索引来索引另一个张量?以上面的x为例,我应该如何做类似以下操作:

    ind_max = tf.argmax(x, dimension=1)    #output is [1,0]
    y = tf.constant([[1,2,3], [6,5,4])
    y_ = y[:, ind_max]                     #y_ should be [2,6]
    

    我知道像最后一行这样的切片在TensorFlow中还不存在(#206)。

    我的问题是:对于我具体的情况,什么是最好的解决方法(也许可以使用其他方法,如gather、select等)?

    额外信息:我知道xy只会是二维张量!

2个回答

75

tf.reduce_max()操作符提供了这种功能。默认情况下,它计算给定张量的全局最大值,但您可以指定一个reduction_indices列表,该列表与NumPy中的axis具有相同的含义。要完成您的示例:

x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, reduction_indices=[1])
print sess.run(x_max)  # ==> "array([220,   4], dtype=int32)"
如果你使用tf.argmax()计算argmax,你可以通过使用tf.reshape()将张量y展平,并按如下方式将argmax索引转换为向量索引,然后使用 tf.gather()提取相应的值:
ind_max = tf.argmax(x, dimension=1)
y = tf.constant([[1, 2, 3], [6, 5, 4]])

flat_y = tf.reshape(y, [-1])  # Reshape to a vector.

# N.B. Handles 2-D case only.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)

y_ = tf.gather(flat_y, flat_ind_max)

print sess.run(y_) # ==> "array([2, 6], dtype=int32)"

为了完整起见,在您的第二个示例中,请在第三行添加:amax = tf.argmax(y, 1) 并删除第一行。 - Sami A. Haija
6
reduction_indices 已被弃用,请使用 axis 替代。 - Binu Jasim

3

TensorFlow 1.10.0-dev20180626版本开始,tf.reduce_max函数接受axiskeepdims参数,提供了与numpy.max类似的功能。

In [55]: x = tf.constant([[1,220,55],[4,3,-1]])

In [56]: tf.reduce_max(x, axis=1).eval() 
Out[56]: array([220,   4], dtype=int32)

为了得到与输入张量相同维度的结果张量,请使用keepdims=True
In [57]: tf.reduce_max(x, axis=1, keepdims=True).eval()Out[57]: 
array([[220],
       [  4]], dtype=int32)

如果未明确指定axis参数,则返回张量级别的最大元素(即所有轴都被缩减) 。
In [58]: tf.reduce_max(x).eval()
Out[58]: 220

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接