numpy.argmin 用于查找大于阈值的元素中最小值的索引

14

我希望找到一个满足特定条件(在我的例子中为中间阈值)的一维NumPy数组中最小值的位置。例如:

import numpy as np

limit = 3
a = np.array([1, 2, 4, 5, 2, 5, 3, 6, 7, 9, 10])
我希望有效地屏蔽所有低于限制值的a中的数字,以使np.argmin的结果为6。是否有一种计算便宜的方法来屏蔽不符合条件的值,然后应用np.argmin

1
你能解释一下为什么在你的问题中np.argmin是6吗?在这种情况下应该是0。如果你把所有小于3的数字遮盖掉,那么你会得到[4,5,5,3,6,7,9,10]。但这个数组的np.argmin仍然不是6。 - OneRaynyDay
@OneRaynyDay 我的猜测是,OP所考虑的掩码数组是 [--, --, 4, 5, --, 5, 3, 6, 7, 9, 10]。然后最小元素是3,位于掩码数组的第6个位置(从0开始计数)。这就是MaxPowers答案中的情况。 - Qaswed
2个回答

20
你可以存储有效的索引,并将其用于从a中选择有效元素,并使用所选元素中的argmin()进行索引以获取最终索引输出。因此,实现应该类似于这样 -
valid_idx = np.where(a >= limit)[0]
out = valid_idx[a[valid_idx].argmin()]

运行示例:

In [32]: limit = 3
    ...: a = np.array([1, 2, 4, 5, 2, 5, 3, 6, 7, 9, 10])
    ...: 

In [33]: valid_idx = np.where(a >= limit)[0]

In [34]: valid_idx[a[valid_idx].argmin()]
Out[34]: 6

运行时测试 -

为了进行性能基准测试,本节将比较基于掩码数组的其他解决方案和此帖子中先前提出的基于常规数组的解决方案在各种数据大小上的表现。

def masked_argmin(a,limit): # Defining func for regular array based soln
    valid_idx = np.where(a >= limit)[0]
    return valid_idx[a[valid_idx].argmin()]

In [52]: # Inputs
    ...: a = np.random.randint(0,1000,(10000))
    ...: limit = 500
    ...: 

In [53]: %timeit np.argmin(np.ma.MaskedArray(a, a<limit))
1000 loops, best of 3: 233 µs per loop

In [54]: %timeit masked_argmin(a,limit)
10000 loops, best of 3: 101 µs per loop

In [55]: # Inputs
    ...: a = np.random.randint(0,1000,(100000))
    ...: limit = 500
    ...: 

In [56]: %timeit np.argmin(np.ma.MaskedArray(a, a<limit))
1000 loops, best of 3: 1.73 ms per loop

In [57]: %timeit masked_argmin(a,limit)
1000 loops, best of 3: 1.03 ms per loop

11

这可以通过使用numpy的MaskedArray轻松实现

import numpy as np

limit = 3
a = np.array([1, 2, 4, 5, 2, 5, 3, 6, 7, 9, 10])
b = np.ma.MaskedArray(a, a<limit)
np.ma.argmin(b)    # == 6

2
为什么不使用 b.argmin() - Carlos Vega
2
另一种可能性是:b = np.ma.masked_less(a, limit) - Alexander Pozdneev

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接