如何根据条件获取numpy数组的索引

3
我遇到了这样一个问题: 假设我有如下数组: a = np.array([[1,2,3,4,5,4,3,2,1],]) label = np.array([[1,0,1,0,0,1,1,0,1],]) 我需要获取a中元素值最大的那个位置的索引,其中label等于1。 可能有些混淆,在上面的例子中,label为1的索引是:0、2、5、6和8,它们对应的a的值分别是:1、3、4、3和1,其中4是最大的,因此我需要得到数字4在a中的索引5的结果。我该如何使用numpy实现这一点?
3个回答

3

获取1s 的索引,命名为idx,然后使用它来索引a,得到max的索引,最后通过索引到idx来跟踪回原始的顺序 -

idx = np.flatnonzero(label==1)
out = idx[a[idx].argmax()]

样例运行 -

# Assuming inputs to be 1D
In [18]: a
Out[18]: array([1, 2, 3, 4, 5, 4, 3, 2, 1])

In [19]: label
Out[19]: array([1, 0, 1, 0, 0, 1, 1, 0, 1])

In [20]: idx = np.flatnonzero(label==1)

In [21]: idx[a[idx].argmax()]
Out[21]: 5

对于作为整数的a和作为0s1s数组的label,我们可以进一步优化,因为我们可以根据a中的值的范围来缩放它,如下所示 -

(label*(a.max()-a.min()+1) + a).argmax()

此外,如果a只有正数,则简化为 -
(label*(a.max()+1) + a).argmax()

正整数较大的a的计时 -

In [115]: np.random.seed(0)
     ...: a = np.random.randint(0,10,(100000))
     ...: label = np.random.randint(0,2,(100000))

In [117]: %%timeit
     ...: idx = np.flatnonzero(label==1)
     ...: out = idx[a[idx].argmax()]
1000 loops, best of 3: 592 µs per loop

In [116]: %timeit (label*(a.max()-a.min()+1) + a).argmax()
1000 loops, best of 3: 357 µs per loop

# @coldspeed's soln
In [120]: %timeit np.ma.masked_where(~label.astype(bool), a).argmax()
1000 loops, best of 3: 1.63 ms per loop

# won't work with negative numbers in a
In [119]: %timeit (label*(a.max()+1) + a).argmax()
1000 loops, best of 3: 292 µs per loop

# @klim's soln (won't work with negative numbers in a)
In [121]: %timeit np.argmax(a * (label == 1))
1000 loops, best of 3: 229 µs per loop

好答案! :-) - cs95

1
您可以使用掩码数组:
>>> np.ma.masked_where(~label.astype(bool), a).argmax()
5

1
这是其中一种最简单的方法。
>>> np.argmax(a * (label == 1))
5
>>> np.argmax(a * (label == 1), axis=1)
array([5])

Coldspeed的方法可能需要更多时间。

1
如果a中有负数会怎么样? - Divakar
Divakar。如果a中有负数或者label中没有匹配项,那么这个方法将无法工作。 - klim
根据刚刚添加的时间,仍然在限制条件下非常快。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接