在numpy数组中找到第n小的元素。

27

我需要在一个一维的 numpy.array 中找到最小的第n个元素。

例如:

a = np.array([90,10,30,40,80,70,20,50,60,0])

我希望得到第五小的元素,所以我的期望输出是40

目前我的解决方案如下:

result = np.max(np.partition(a, 5)[:5])

然而,对我来说找到5个最小元素,然后从它们中选出最大的那个似乎有点笨拙。是否有更好的方法呢?我错过了一个能实现我的目标的函数吗?

有类似标题的问题,但我没有看到任何回答我的问题。

编辑:

我本应该在原文中提到这一点,但性能对我非常重要;因此,heapq 的解决方案虽然不错,但对我无法使用。

import numpy as np
import heapq

def find_nth_smallest_old_way(a, n):
    return np.max(np.partition(a, n)[:n])

# Solution suggested by Jaime and HYRY    
def find_nth_smallest_proper_way(a, n):
    return np.partition(a, n-1)[n-1]

def find_nth_smallest_heapq(a, n):
    return heapq.nsmallest(n, a)[-1]
#    
n_iterations = 10000

a = np.arange(1000)
np.random.shuffle(a)

t1 = timeit('find_nth_smallest_old_way(a, 100)', 'from __main__ import find_nth_smallest_old_way, a', number = n_iterations)
print 'time taken using partition old_way: {}'.format(t1)    
t2 = timeit('find_nth_smallest_proper_way(a, 100)', 'from __main__ import find_nth_smallest_proper_way, a', number = n_iterations)
print 'time taken using partition proper way: {}'.format(t2) 
t3 = timeit('find_nth_smallest_heapq(a, 100)', 'from __main__ import find_nth_smallest_heapq, a', number = n_iterations)  
print 'time taken using heapq : {}'.format(t3)

结果:

time taken using partition old_way: 0.255564928055
time taken using partition proper way: 0.129678010941
time taken using heapq : 7.81094002724

另外,查看http://docs.python.org/2/library/heapq.html可能会有所帮助。 - C.B.
2
@C.B. 上面的问题与我的问题显然有很大不同;它要求找到最小值和最大值,而且是针对二维矩阵的。 - Akavall
4
这怎么是一个重复的问题?标题听起来相似,但问题本身非常不同。有时候不同的问题会得到相同的答案,但这里的答案也非常不同。而且那个问题中的任何答案都不可能是我的问题的答案。 - Akavall
3个回答

43

除非我漏掉了什么,你想要做的是:

>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> np.partition(a, 4)[4]
40

np.partition(a, k) 会将 a 中第 k+1 小的元素放置在 a[k],比 a[k] 小的值在 a[:k] 中,比 a[k] 大的值在 a[k+1:] 中。唯一需要注意的是,由于索引从0开始计数,因此第五个元素的索引为4。


是的,就是这样。我想错了。我知道有更好的解决方案! - Akavall
2
它应该是 np.partition(a, 4)[3]。 - heroxbd
好的,第五个元素。 - heroxbd
1
发现k必须大于或等于方括号[]中的数字。否则会得到错误的答案(我期望它会是一个错误)。我留下这个评论是为了防止有人滥用它以获取错误的答案。 - Isaac Sim
"np.partition(a, 4)[3]" 这个表达式通常不是你想要的 -- 元素 [4] 肯定是第五小的。而元素 [3] 只是其中四个最小值之一,但你不知道它是哪一个。 - Carl Walsh

5
你可以使用 heapq.nsmallest 函数:
>>> import numpy as np
>>> import heapq
>>> 
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> heapq.nsmallest(5, a)[-1]
40

3
请检查一下你的表现。最近我遇到这样一种情况,使用heapq.nsmallest看起来很完美,但是使用排序并切片的方式实际上是快了大约25%。我相信堆的方法对于某些数据来说更快,但并非所有数据都适用。我不知道是否有关于NumPy数组方面的特殊情况会导致其中一种方法更好。 - Peter DeGlopper
@PeterDeGlopper 嗯,对于较小的数据集,排序方法可能更快,但对于较大的数据集,堆方法应该更快。您所提到的数据有多大? - arshajii
我在原帖中提供的解决方案是O(n),因为np.partitionnp.max都是O(n)。 - Akavall
我见过一些实例,实际上使用 heapify 加上 n 次 heappop 操作比使用 nsmallest 或者切片的 sorted 更快。只是提醒一下。 - roippi
@Akavall 这也是O(n)。堆可以在O(n)的时间内构建,然后5个弹出操作是常数时间。我很惊讶你目前的方法如此之快。 - arshajii
显示剩余2条评论

2
你不需要调用 numpy.max():
def nsmall(a, n):
    return np.partition(a, n)[n]

1
它应该是 np.partition(a, n)[n-1]。 - heroxbd
1
@heroxbd,应该是np.partition(a, n)[n],而不是np.partition(a, n)[n-1]。numpy 1.24的文档中写道: "元素的第k个值将在其最终排序位置上,并且所有较小的元素都将被移动到它之前,所有相等或更大的元素都将被移动到它之后。" 参数kth=0是有效的值,它会导致函数将最小的元素放在开头,因此np.partition(a, kth=0)[0]对于其他n>0同样适用。 - Charlie

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接