如何在tf.Tensor中找到最大值?

3

我该如何找到每个元素中的最大值,以便得到2、4、6、8?

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

我尝试了以下代码:
tf.reduce_max(a, keepdims=True)

但这只会输出8而忽略其余部分。

reduce_max有一个axis参数,可以尝试使用它。 - xdurch0
1个回答

4

您需要将axis参数更改为-1,像这样:

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

print(tf.reduce_max(a, axis=-1, keepdims=False))

'''
tf.Tensor(
[[2]
 [4]
 [6]
 [8]], shape=(4, 1), dtype=int32)
'''

由于您有一个三维张量并且想访问最后一个维度。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接