我会提名
a.argmax()
使用@fuglede
的测试数组:
In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999
In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop
In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop
我没有安装numba
,所以无法进行比较。但相对于short
,我的加速比超过了@fuglede's
的6倍。
我在Py3中进行测试,它接受<np.nan
,而Py2则会引发运行时警告。但代码搜索表明,这并不依赖于该比较。
/numpy/core/src/multiarray/calculation.c
的PyArray_ArgMax
操纵轴(将感兴趣的轴移动到末尾),并将操作委托给arg_func=PyArray_DESCR(ap)->f->argmax
,一个取决于dtype的函数。
在numpy/core/src/multiarray/arraytypes.c.src
中,看起来BOOL_argmax
短路,一旦遇到True
就返回。
for (; i < n; i++) {
if (ip[i]) {
*max_ind = i;
return 0;
}
}
而且@fname@_argmax
在最大的nan
上也会短路。 在argmin
中,np.nan
同样是“最大”的。
#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered; it's maximal */
return 0;
}
#endif
欢迎有经验的C程序员对此进行评论,但在我看来,至少对于np.nan
而言,一个简单的argmax
就足以达到最快的速度。
通过调整生成a
中的9999
值,可以发现a.argmax
的时间取决于该值,与短路效应一致。
dup
地址提供了“短路”的替代方案。isnan
部分并没有使这个问题具有唯一性。但是,dup
已经过时了。 - hpauljnan
出现的位置,而整个数组代码则稳定,仅取决于数组的总长度(即使我已经将长度增加到1000以上才能看到变化)。 - hpaulj