如何从numpy多维数组中获取前k个最大值的索引

9

我在StackOverflow上看了几个问题,但没有找到相关的答案。 我想从numpy ndarray中获取k个最大值的索引。这个链接讨论了相同的问题,但是只针对1D数组。 对于2D数组,np.argsort会按行排序元素。

Note: array elements are not unique.

输入:

import numpy as np
n = np.arange(9).reshape(3,3)
>>> n
array([[0, 1, 2],
   [3, 4, 5],
   [6, 7, 8]])
s = n.argsort()
>>> s
array([[0, 1, 2],
   [0, 1, 2],
   [0, 1, 2]], dtype=int32)

此外,
import numpy as np
n = np.arange(9).reshape(3,3)
s = n.argsort(axis=None)
>>>s
array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int32)

但我在这里失去了数组结构,无法恢复元素的原始索引。
感谢任何形式的帮助。

你需要针对2D数组或n维数组(其中n>2)的答案吗? - Chris
1个回答

12

使用 np.argpartitionnp.argsort 两种方法对 ndarrays 进行排序 -

def k_largest_index_argpartition_v1(a, k):
    idx = np.argpartition(-a.ravel(),k)[:k]
    return np.column_stack(np.unravel_index(idx, a.shape))

def k_largest_index_argpartition_v2(a, k):
    idx = np.argpartition(a.ravel(),a.size-k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))

def k_largest_index_argsort(a, k):
    idx = np.argsort(a.ravel())[:-k-1:-1]
    return np.column_stack(np.unravel_index(idx, a.shape))

讨论使用argpartition的两个版本

k_largest_index_argpartition_v1k_largest_index_argpartition_v2之间的差异在于我们如何使用argparition。在第一个版本中,我们对输入数组取负值,然后使用argpartition获取最小的k个索引,从而有效地获取最大的k个索引;而在第二个版本中,我们获取前a.size-k个最小的索引,然后选择剩余的最大的k个索引。

此外,在使用argpartition时,值得注意的是我们不能按照它们的排序顺序获取索引。如果需要按排序顺序获取索引,则需要将范围数组提供给np.argpartition,正如这个post中所述。

样例运行 -

1) 二维情况:

In [42]: a    # 2D array
Out[42]: 
array([[38, 14, 81, 50],
       [17, 65, 60, 24],
       [64, 73, 25, 95]])

In [43]: k_largest_index_argsort(a, k=2)
Out[43]: 
array([[2, 3],
       [0, 2]])

In [44]: k_largest_index_argsort(a, k=4)
Out[44]: 
array([[2, 3],
       [0, 2],
       [2, 1],
       [1, 1]])

In [66]: k_largest_index_argpartition_v1(a, k=4)
Out[66]: 
array([[2, 1], # Notice the order is different
       [2, 3],
       [0, 2],
       [1, 1]])

2) 三维情况:

In [46]: a # 3D array
Out[46]: 
array([[[20, 98, 27, 73],
        [33, 78, 48, 59],
        [28, 91, 64, 70]],

       [[47, 34, 51, 19],
        [73, 38, 63, 94],
        [95, 25, 93, 64]]])

In [47]: k_largest_index_argsort(a, k=2)
Out[47]: 
array([[0, 0, 1],
       [1, 2, 0]])

运行时测试 -

In [56]: a = np.random.randint(0,99999999999999,(3000,4000))

In [57]: %timeit k_largest_index_argsort(a, k=10)
1 loops, best of 3: 2.18 s per loop

In [58]: %timeit k_largest_index_argpartition_v1(a, k=10)
10 loops, best of 3: 178 ms per loop

In [59]: %timeit k_largest_index_argpartition_v2(a, k=10)
10 loops, best of 3: 128 ms per loop

仅供澄清,argpartition 不会按顺序给出前 k 个,而只是按其最初的顺序给出前 k 个。既能获得 argpartition 的速度优势,又能获得 argsort 的排序优势的方法是,在分区之后进行排序:执行 idx2=np.argsort(-a.ravel()[idx]) ,然后 row,col = np.unravel.index(idx[idx2], a.shape) - Daniel F
此外,在这里执行-a.ravel()的成本有些显著。 - Eric
@Eric 没错,我发现了一个奇怪的事情,对于 k=2argpartition 的运行时间很短,但是对于 3 及以上的值,它会飙升并在此后保持大致恒定。因此,如果我们执行以下操作:a = np.random.randint(0,99999999999999,(10000000)) 然后 np.argpartition(a,2) 的时间为 39毫秒,而当 k=3 时,时间则增加到了 90.6毫秒 - Divakar
@Eric,我在argpartition方面做出了错误的假设,已经将其删除。感谢您指出! - Divakar
1
@mLstudent33 当我在这里引用其他地方的代码时,我会在函数代码顶部放置链接和作者名称,并使用 @ 符号。就像这里 - https://stackoverflow.com/a/51951787/。因此,我想这种方式也可以用于作业或 Stackoverflow 之外的场合。 - Divakar
显示剩余6条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接