在numpy中计算数组值超过阈值的最快方法

8

我有一个包含10^8个浮点数的numpy数组,想要统计其中有多少个数大于或等于给定的阈值。由于需要对大量这样的数组进行操作,因此速度至关重要。目前的竞争者有:

np.sum(myarray >= thresh)

np.size(np.where(np.reshape(myarray,-1) >= thresh))

答案在这里,建议使用np.where()会更快,但我发现时间结果不一致。我的意思是对于某些实现和布尔条件,np.size(np.where(cond))比np.sum(cond)更快,但对于某些实现和布尔条件,它却更慢。
具体来说,如果大部分条目都满足条件,则np.sum(cond)明显更快,但如果只有很少的条目(可能少于十分之一)满足条件,则np.size(np.where(cond))获胜。
问题分为两个部分:
  • 还有其他建议吗?
  • 使用np.size(np.where(cond))所需的时间随着cond为真的条目数的增加而增加是否合理?

numexpr或numba可能通过避免创建中间数组来加速处理速度。 - user2357112
1
还有np.count_nonzero,它比新的numpy版本上的布尔和更快。 - seberg
1个回答

3

使用Cython可能是一个不错的选择。

import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


DTYPE_f64 = np.float64
ctypedef np.float64_t DTYPE_f64_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int count_above_cython(DTYPE_f64_t [:] arr_view, DTYPE_f64_t thresh) nogil:

    cdef int length, i, total
    total = 0
    length = arr_view.shape[0]

    for i in prange(length):
        if arr_view[i] >= thresh:
            total += 1

    return total


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def count_above(np.ndarray arr, DTYPE_f64_t thresh):

    cdef DTYPE_f64_t [:] arr_view = arr.ravel()
    cdef int total

    with nogil:
       total =  count_above_cython(arr_view, thresh)
    return total

不同提出的方法的时间安排。

myarr = np.random.random((1000,1000))
thresh = 0.33

In [6]: %timeit count_above(myarr, thresh)
1000 loops, best of 3: 693 µs per loop

In [9]: %timeit np.count_nonzero(myarr >= thresh)
100 loops, best of 3: 4.45 ms per loop

In [11]: %timeit np.sum(myarr >= thresh)
100 loops, best of 3: 4.86 ms per loop

In [12]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
10 loops, best of 3: 61.6 ms per loop

使用更大的数组:

In [13]: myarr = np.random.random(10**8)

In [14]: %timeit count_above(myarr, thresh)
10 loops, best of 3: 63.4 ms per loop

In [15]: %timeit np.count_nonzero(myarr >= thresh)
1 loops, best of 3: 473 ms per loop

In [16]: %timeit np.sum(myarr >= thresh)
1 loops, best of 3: 511 ms per loop

In [17]: %timeit np.size(np.where(np.reshape(myarr,-1) >= thresh))
1 loops, best of 3: 6.07 s per loop

我猜这将取决于硬件,在cython中,您可以更容易地并行化。在cython中使用-O3(没有它会很慢)和开发numpy,在我的电脑上它们的性能非常接近(稍微优势为cython,但是numpy代码在不连续数组方面要快得多,当然您可以修复它)。但是,您应该真正使用ssize_t/np.intp_t而不是int,否则就是一个错误。 - seberg

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接