获取1s
的索引,命名为idx
,然后使用它来索引a
,得到max
的索引,最后通过索引到idx
来跟踪回原始的顺序 -
idx = np.flatnonzero(label==1)
out = idx[a[idx].argmax()]
样例运行 -
# Assuming inputs to be 1D
In [18]: a
Out[18]: array([1, 2, 3, 4, 5, 4, 3, 2, 1])
In [19]: label
Out[19]: array([1, 0, 1, 0, 0, 1, 1, 0, 1])
In [20]: idx = np.flatnonzero(label==1)
In [21]: idx[a[idx].argmax()]
Out[21]: 5
对于作为整数的a
和作为0s
和1s
数组的label
,我们可以进一步优化,因为我们可以根据a
中的值的范围来缩放它,如下所示 -
(label*(a.max()-a.min()+1) + a).argmax()
此外,如果
a
只有正数,则简化为 -
(label*(a.max()+1) + a).argmax()
正整数较大的a
的计时 -
In [115]: np.random.seed(0)
...: a = np.random.randint(0,10,(100000))
...: label = np.random.randint(0,2,(100000))
In [117]: %%timeit
...: idx = np.flatnonzero(label==1)
...: out = idx[a[idx].argmax()]
1000 loops, best of 3: 592 µs per loop
In [116]: %timeit (label*(a.max()-a.min()+1) + a).argmax()
1000 loops, best of 3: 357 µs per loop
In [120]: %timeit np.ma.masked_where(~label.astype(bool), a).argmax()
1000 loops, best of 3: 1.63 ms per loop
In [119]: %timeit (label*(a.max()+1) + a).argmax()
1000 loops, best of 3: 292 µs per loop
In [121]: %timeit np.argmax(a * (label == 1))
1000 loops, best of 3: 229 µs per loop