在一个numpy数组中找到n个最小值

3

这里有很多问题,其中一个想在numpy数组中找到第n个最小的元素。但是,如果您有一个数组的数组呢?如下所示:

>>> print matrix
[[ 1.          0.28958002  0.09972488 ...,  0.46999924  0.64723113
   0.60217694]
 [ 0.28958002  1.          0.58005657 ...,  0.37668355  0.48852272
   0.3860152 ]
 [ 0.09972488  0.58005657  1.         ...,  0.13151364  0.29539992
   0.03686381]
 ..., 
 [ 0.46999924  0.37668355  0.13151364 ...,  1.          0.50250212
   0.73128971]
 [ 0.64723113  0.48852272  0.29539992 ...,  0.50250212  1.          0.71249226]
 [ 0.60217694  0.3860152   0.03686381 ...,  0.73128971  0.71249226  1.        ]]

如何从这个二维数组中获取前n小的元素?
>>> print type(matrix)
<type 'numpy.ndarray'>

以下是我找到最小项坐标的方法:

min_cordinates = []
for i in matrix:
    if numpy.any(numpy.where(i==numpy.amin(matrix))[0]):
        min_cordinates.append(int(numpy.where(i==numpy.amin(matrix))[0][0])+1)

现在我想找到前10个最小的项目。

3个回答

6

将矩阵展开,排序后选择前10个。

print(numpy.sort(matrix.flatten())[:10])

1
你可以使用numpy.sort(matrix, axis=None)[:10]代替调用matrix.flatten() - Warren Weckesser

6
如果您的数组不是很大,那么接受的答案就可以了。但是对于大型数组,np.partition 可以更有效地完成此操作。这里有一个示例,该数组有10000个元素,您想找到最小的10个值。
In [56]: np.random.seed(123)

In [57]: a = 10*np.random.rand(100, 100)

使用 np.partition 获取前 10 个最小值:
In [58]: np.partition(a, 10, axis=None)[:10]
Out[58]: 
array([ 0.00067838,  0.00081888,  0.00124711,  0.00120101,  0.00135942,
        0.00271129,  0.00297489,  0.00489126,  0.00556923,  0.00594738])

请注意,这些值并不是按递增顺序排列的。 np.partition不能保证前10个值已排序。如果您需要按递增顺序排列它们,可以在选择的值之后对其进行排序。这仍然比对整个数组进行排序要快。
以下是使用np.sort的结果:
In [59]: np.sort(a, axis=None)[:10]
Out[59]: 
array([ 0.00067838,  0.00081888,  0.00120101,  0.00124711,  0.00135942,
        0.00271129,  0.00297489,  0.00489126,  0.00556923,  0.00594738])

现在比较一下时间:
In [60]: %timeit np.partition(a, 10, axis=None)[:10]
10000 loops, best of 3: 75.1 µs per loop

In [61]: %timeit np.sort(a, axis=None)[:10]
1000 loops, best of 3: 465 µs per loop

在这种情况下,使用np.partition的速度要快六倍以上。

3
你可以使用heapq.nsmallest函数获取前10个最小元素的列表。
In [84]: import heapq

In [85]: heapq.nsmallest(10, matrix.flatten())
Out[85]: 
[-1.7009047695355393,
 -1.4737632239971061,
 -1.1246243781838825,
 -0.7862983016935523,
 -0.5080863016259798,
 -0.43802651199959347,
 -0.22125698200832566,
 0.034938408281615596,
 0.13610084041121048,
 0.15876389111565958]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接