寻找numpy数组中k个最小值的索引

121
为了找到最小值的索引,我可以使用 argmin:
import numpy as np
A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
print A.argmin()     # 4 because A[4] = 0.1

但是我该如何找到前k个最小值的索引呢?

我需要像这样的东西:

print A.argmin(numberofvalues=3)   
# [4, 0, 7]  because A[4] <= A[0] <= A[7] <= all other A[i]

注意:在我的使用情况下,A具有大约10,000到100,000个值,并且我只对k = 10个最小值的索引感兴趣。 k永远不会> 10。


2
请参考这个问题,特别是其中的第二个答案,以获取最佳解决方案(它是O(n) - 完全排序整个数组并不是绝对必要的)。 - Alex Riley
类似内容:https://dev59.com/IZDea4cB1Zd3GeqPie7k - Wok
4个回答

194

使用np.argpartition,它不会对整个数组进行排序。它只保证第k个元素在排序位置上,所有比它小的元素都会被移动到它之前,因此前k个元素将是最小的k个元素。

import numpy as np

A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
k = 3

idx = np.argpartition(A, k)
print(idx)
# [4 0 7 3 1 2 6 5]

这将返回前k个最小的值。请注意,它们可能不按排序顺序排列。

print(A[idx[:k]])
# [ 0.1  1.   1.5]

要获取前k个最大值,请使用


idx = np.argpartition(A, -k)
# [4 0 7 3 1 2 6 5]

A[idx[-k:]]
# [  9.  17.  17.]

警告:不要重新使用idx = np.argpartition(A, k);A[idx[-k:]]来获取前k个最大值。这并不总是有效的。例如,以下不是x中的前3个最大值:

x = np.array([100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
idx = np.argpartition(x, 3)
x[idx[-3:]]
array([ 70,  80, 100])
这里是与np.argsort的比较,它也可以工作,但只是对整个数组进行排序以获得结果。
In [2]: x = np.random.randn(100000)

In [3]: %timeit idx0 = np.argsort(x)[:100]
100 loops, best of 3: 8.26 ms per loop

In [4]: %timeit idx1 = np.argpartition(x, 100)[:100]
1000 loops, best of 3: 721 µs per loop

In [5]: np.alltrue(np.sort(np.argsort(x)[:100]) == np.sort(np.argpartition(x, 100)[:100]))
Out[5]: True

有什么想法来解决平局吗?如果您想要随机解决平局,似乎唯一可能的方法是使用lexsort对整个数组进行排序。https://dev59.com/cmIj5IYBdhLWcg3wflMy#20199459分区文档说明introselect不稳定,但我不确定这是否意味着平局是随机解决的。 - user27182
@user27182:根据文档,如果a是一个带有字段的数组(即结构化数组),那么您可以指定一个order或让未指定的字段用于打破平局。因此,如果您将A倒入结构化数组的第一个字段中,然后将随机(平局打破)数字倒入第二个字段中,则可以使用np.argpartition选择具有随机平局打破的k个最小值(或最大值)。 - unutbu
3
请记住,前 k-1 个元素的顺序不一定是从最小到最大的。如果这是您需要的内容,可以使用 np.argpartition 函数,在数组上切片使用前 k 个索引,然后在结果数组上使用 np.argsort 函数。 - Leland Hepworth

31

您可以使用 numpy.argsort 与切片

>>> import numpy as np
>>> A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
>>> np.argsort(A)[:3]
array([4, 0, 7], dtype=int32)

3
谢谢!但是如果必须计算所有argsort,然后才保留前k (<=10)个值,这样不会非常慢吗? - Basj
很难说不知道argsort的实现方式。具体来说,如果它被实现为生成器,并且根据实际的排序算法,它可能是惰性的,或者它可能首先对整个集合进行排序,我不确定。 - Cory Kramer
4
从其他的评论中可以看出,argsort 对整个集合进行排序,所以我更喜欢使用其他建议中使用 argpartition 的解决方案。 - Cory Kramer
1
argpartition 相比,这种解决方案的好处在于我们可以保证所寻找的 k 个索引是按升序排列的。 - twink_ml

6
对于n维数组,这个函数很有效。索引以可调用的形式返回。如果你想要返回一个索引列表,则需要在生成列表之前对数组进行转置。
要检索前k个最大值,只需传入-k即可。
def get_indices_of_k_smallest(arr, k):
    idx = np.argpartition(arr.ravel(), k)
    return tuple(np.array(np.unravel_index(idx, arr.shape))[:, range(min(k, 0), max(k, 0))])
    # if you want it in a list of indices . . . 
    # return np.array(np.unravel_index(idx, arr.shape))[:, range(k)].transpose().tolist()

例子:

r = np.random.RandomState(1234)
arr = r.randint(1, 1000, 2 * 4 * 6).reshape(2, 4, 6)

indices = get_indices_of_k_smallest(arr, 4)
indices
# (array([1, 0, 0, 1], dtype=int64),
#  array([3, 2, 0, 1], dtype=int64),
#  array([3, 0, 3, 3], dtype=int64))

arr[indices]
# array([ 4, 31, 54, 77])

%%timeit
get_indices_of_k_smallest(arr, 4)
# 17.1 µs ± 651 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

我一直在寻找n维情况的答案,而你的解决方案正如预期一样有效!此外,它是一个快速的解决方案。 - emil

0

numpy.partition(your_array, k) 是一种替代方法。它不需要切片,因为它会将值排序直到第 k 个元素。


1
这将位于索引k(可能未排序)的元素放在排序位置。由于排序位置不一定是索引kk-1,因此我们无法保证numpy.partitionyour_array[:k]包含前k个最小的元素。 - twink_ml
这是2019年10月针对数组值(而非索引)的最佳答案。@protagonist 我不理解你的评论。如果我错了,请纠正我,但证明此分区函数正确工作的方法是在循环中运行以下内容:y = np.arange(10) ; np.random.shuffle(y) ; y.partition(3) ; assert y[:3+1].max() < y[3+1:].min() ... 分区函数在旧版numpy中是否有不同的行为?另外FYI:k是从零开始计数的。 - Alex Gaudio

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接