获取numpy数组中的最大或最小n个元素?(最好不要展平)

7

我知道可以使用以下方法获取最小值或最大值:

max(matrix)
min(matrix)

从numpy矩阵/向量中筛选出特定的值,可以使用以下代码获取这些值的索引:
argmax(matrix)
argmin(matrix)

比如,当我有一个5x5的矩阵:

a = np.arange(5*5).reshape(5, 5) + 10

# array([[10, 11, 12, 13, 14],
#        [15, 16, 17, 18, 19],
#        [20, 21, 22, 23, 24],
#        [25, 26, 27, 28, 29],
#        [30, 31, 32, 33, 34]])

我可以通过以下方式获取最大值:

In [86]: np.max(a) # getting the max-value out of a
Out[86]: 34

In [87]: np.argmax(a) # index of max-value 34 is 24 if array a were flattened
Out[87]: 24

...但是获取前n个最大或最小元素的最有效方法是什么?

假设我想从a中获取5个最高和5个最低元素。这应该分别返回[30, 31, 32, 33, 34]作为5个最高值,以及它们的索引[20, 21, 22, 23, 24]。同样地,对于5个最低值,应返回[10, 11, 12, 13, 14]和其5个最低元素的索引[0, 1, 2, 3, 4]

有没有一种高效、合理的解决方案呢?

我的第一个想法将数组展平并排序,然后取最后和最前的5个值。之后通过原始的2D矩阵查找这些值的索引。虽然这个过程是可行的,但展平+排序并不是很高效...是否有更快的解决方案?

此外,我想要原始2D数组的索引,而不是展平后的索引。所以,我想要的是(4, 4)而不是np.argmax(a)返回的24


1
np.partition(以及用于索引的np.argpartition)的时间复杂度为O(n) - 我认为这是您在此处所能期望的最好结果。它需要先对数组进行展平操作(这应该只会创建一个视图,因此不会产生任何性能损失)。然后,您可以使用unravel_index来获取原始数组中的二维索引。 - Alex Riley
1个回答

4
使用 np.argpartition 是获取数组中最大或最小值的索引的标准方法。此函数使用 introselect 算法,并以线性复杂度运行 - 对于较大的数组,这比完全排序的表现更好(通常为 O(n log n))。
默认情况下,此函数沿数组的最后一个轴运行。要考虑整个数组,您需要使用 ravel()。例如,这是一个随机数组a
>>> a = np.random.randint(0, 100, size=(5, 5))
>>> a
array([[60, 68, 86, 66,  9],
       [66, 26, 83, 87, 50],
       [41, 26,  0, 55,  9],
       [57, 80, 71, 50, 22],
       [94, 30, 95, 99, 76]])

如果要获取(扁平化的)二维数组中前五个最大值的索引,可以使用以下代码:

>>> i = np.argpartition(a.ravel(), -5)[-5:] # argpartition(a.ravel(), 5)[:5] for smallest
>>> i
array([ 2,  8, 22, 23, 20])

要获取在数组 a 中对应这些位置的二维索引,请使用 unravel_index 方法:
>>> i2d = np.unravel_index(i, a.shape)
>>> i2d
(array([0, 1, 4, 4, 4]), array([2, 3, 2, 3, 0]))

然后使用i2d索引a将返回五个最大的值:

>>> a[i2d]
array([86, 87, 95, 99, 94])

在这种情况下,排序更快:%timeit a.ravel.argpartition(-5) -> 5.5 µs%timeit a.ravel.argsort() -> 3.8 µs。但是当然对于更大的数组,这才是正确的方法。 - B. M.

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接