Numpy 3D数组最大值

3
import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
           [ 0.50,  0.60 ,  0.70 ,  0.30]],
          [[ 0.25,  0.50 ,  0.20 ,  0.70],
           [ 0.80,  0.10 ,  0.50 ,  0.15]]])

我需要找到a[i]中最大值的行和列。 如果i=0,那么a[0,1,2]是最大值,所以我需要编写一个方法,使得在a[0]中最大值的输出为[1,2]。请给出任何指针? 注意:np.argmax会将a[i] 2D数组展开,当axis=0时,它会给出a[0]每行最大值的索引。


我建议看一下numpy.argmax。这可以通过获取索引,然后使用数组的形状进行一些快速计算来完成。 - David
3个回答

4

您也可以使用 argmax 结合 unravel_index 来实现:

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])

e.g.

import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
               [ 0.50,  0.60 ,  0.70 ,  0.30]],
              [[ 0.25,  0.50 ,  0.20 ,  0.70],
               [ 0.80,  0.10 ,  0.50 ,  0.15]]])

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])


print(max_by_index(0, a))

提供

(0, 1, 2)


太晚了,我正要发布这个。和你在编辑中纠正的错误一样 ;) - JHBonarius

2
您可以使用numpy.where,您可以将其包装在一个简单的函数中以满足您的要求:
def max_by_index(idx, arr):
    return np.where(arr[idx] == np.max(arr[idx]))

实际应用:

>>> max_by_index(0, a)
(array([1], dtype=int64), array([2], dtype=int64))

您可以使用此结果对数组进行索引,以访问最大值:
>>> a[0][max_by_index(0, a)]
array([0.7])

如果你只想要最大值出现一次,可以将np.max替换为np.argmax。这将返回最大值的所有位置,而不是仅返回一个位置。


请注意,使用==比较浮点数时,np.where可能根本不返回任何内容。 - dtasev

0
col = (np.argmax(a[i])) % (a[i].shape[1])
row = (np.argmax(a[i])) // (a[i].shape[1])

这也有所帮助


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接