Numpy:如何使用argmax结果获取实际的最大值?

7
假设我有一个三维数组:
>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

我可以使用 argmaxaxis=1 上获取它的值。
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

如何使用m作为索引进入a,使结果等同于直接使用max

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

当应用m到其他相同形状的数组时,这将非常有用。

2个回答

5
你可以通过 高级索引numpy广播 实现此操作:
m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])

m = np.argmax(a, axis=1)

创建一维、二维和三维索引数组:
ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
​

由于维度不匹配,三个数组将进行广播,结果如下:
for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 

由于所有的索引都是整数数组,因此它触发了高级索引,所以选择具有索引(0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)的元素,即每个数组中组合成的一个元素。


感谢您展示 np.broadcast_arrays - Alex Kreimer

3
你可以使用 np.ogrid 为除了你减少的轴以外的所有轴创建一个网格。然后只需将 argmax 的结果插入到你的轴的位置,并使用该结果对数组进行索引:index
>>> import numpy as np
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
>>> axis = 1

>>> # Create the grid
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]])
>>> argmaxes = np.argmax(a, axis=axis)
>>> idx.insert(axis, argmaxes)

>>> # Index the original array with the grid
>>> a[idx]
array([[7, 6],
       [5, 4]])

请注意,对于 axis=None 或在您减少多个轴的情况下,此方法不适用。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接