寻找第一个np.nan值的最有效方法是什么?

15

考虑数组a

a = np.array([3, 3, np.nan, 3, 3, np.nan])

我可以做

np.isnan(a).argmax()

但这需要找到所有的np.nan才能找到第一个。
有更有效率的方法吗?


我一直在尝试弄清楚是否可以向np.argpartition传递参数,使得np.nan首先排序而不是最后。


关于[dup]的编辑。
这个问题与其他问题有几个不同之处:

  1. 那些问题和答案涉及值的等式。这是关于isnan的。
  2. 那些答案都面临着我回答所面临的相同问题。请注意,我提供了一个完全有效的答案,但强调了它的低效性。我要解决这个效率问题。

关于第二个[dup]的编辑。

仍然涉及等式,问题和答案都很陈旧,很可能已经过时了。


1
可能是Numpy:快速查找值的第一个索引的重复问题。 - fuglede
1
第二个 dup 地址提供了“短路”的替代方案。 isnan 部分并没有使这个问题具有唯一性。但是,dup 已经过时了。 - hpaulj
Cython/ctypes/JIT 可能是处理大数组“短路”的最佳方式,正如第二个重复显示的那样。 - Imanol Luengo
2
为了使这个公平,您需要设计一个定时框架。短路时间严重取决于第一个nan出现的位置,而整个数组代码则稳定,仅取决于数组的总长度(即使我已经将长度增加到1000以上才能看到变化)。 - hpaulj
4个回答

12

值得一提的是,可以考虑使用numba.jit;没有它,在大多数情况下,向量化版本很可能会打败纯Python搜索的直接实现,但在编译代码后,普通搜索将占据主导地位,至少在我的测试中如此:

In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

In [70]: %paste
import numba

def naive(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

def short(a):
        return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

@numba.jit
def short_jit(a):
        return np.isnan(a).argmax()
## -- End pasted text --

In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop

In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop

In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop

In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop

编辑:如@hpaulj在他们的答案中指出的那样,numpy实际上提供了一个优化的短路搜索,其性能与上面JIT编译的搜索相当:

In [26]: %paste
def plain(a):
        return a.argmax()

@numba.jit
def plain_jit(a):
        return a.argmax()
## -- End pasted text --

In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop

In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop

In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop

In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop

10

我会提名

a.argmax()

使用@fuglede的测试数组:

In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999

In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop

In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop

我没有安装numba,所以无法进行比较。但相对于short,我的加速比超过了@fuglede's的6倍。

我在Py3中进行测试,它接受<np.nan,而Py2则会引发运行时警告。但代码搜索表明,这并不依赖于该比较。

/numpy/core/src/multiarray/calculation.cPyArray_ArgMax操纵轴(将感兴趣的轴移动到末尾),并将操作委托给arg_func=PyArray_DESCR(ap)->f->argmax,一个取决于dtype的函数。

numpy/core/src/multiarray/arraytypes.c.src中,看起来BOOL_argmax短路,一旦遇到True就返回。

for (; i < n; i++) {
    if (ip[i]) {
        *max_ind = i;
        return 0;
    }
}

而且@fname@_argmax在最大的nan上也会短路。 在argmin中,np.nan同样是“最大”的。

#if @isfloat@
    if (@isnan@(mp)) {
        /* nan encountered; it's maximal */
        return 0;
    }
#endif

欢迎有经验的C程序员对此进行评论,但在我看来,至少对于np.nan而言,一个简单的argmax就足以达到最快的速度。

通过调整生成a中的9999值,可以发现a.argmax的时间取决于该值,与短路效应一致。


现在我能够清晰地思考了。这太棒了!而且这一切都归结于np.nan是最大的。 - piRSquared
有趣!在我的测试设置中,普通的 argmax 和 JITted 搜索表现一样好。我猜这是有道理的,因为它们显然在做同样的事情!我会将时间添加到其他答案中。 - fuglede

6

这里是一个使用 itertools.takewhile() 的Python方法:

from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))

使用生成器表达式和next方法进行基准测试:1
In [118]: a = np.repeat(a, 10000)

In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop

In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop

但仍然(远远)比numpy方法慢:
In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop

1. 这种方法的问题在于使用了enumerate函数。它首先从numpy数组返回一个enumerate对象(这是一个类似迭代器的对象),调用生成器函数和迭代器的next属性将会花费时间。


3
这个数组太小了,这些方法无法击败NumPy的函数。或许可以尝试使用一个更大的数组? - ayhan
2
@ayhan 实际上这是一个重复的数组。我只是忘记添加相关命令了。这是一个新的。 - Mazdak
这是 itertools.takewhile - hpaulj

3

在不同的场景中寻找第一个匹配项时,我们可以遍历并查找第一个匹配项,并在第一个匹配项处退出,而不是处理整个数组。所以,我们可以使用 Python 的 next 函数 来实现这一方法,就像这样 -

next((i for i, val in enumerate(a) if np.isnan(val)))

样例运行 -

In [192]: a = np.array([3, 3, np.nan, 3, 3, np.nan])

In [193]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[193]: 2

In [194]: a[2] = 10

In [195]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[195]: 5

我原本想建议使用带有中断(break)的循环,但这个使用生成器加一个next同样有效。不过我对速度有些担忧。 - hpaulj
3
对于非常大的数组和起始部分有NaN值的情况,使用“迭代至第一个有机会更快”。但通常编译好的NumPy函数将更快,即使它们遍历整个数组。需要哪些测试套件? - hpaulj
我甚至尝试使用了collections中的deque,但它并没有起到帮助作用。 - piRSquared
2
在第一个“重复”中,展示了这个生成器的下一个解决方案。 - hpaulj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接