Numba代码比纯Python代码运行慢

20

我一直在努力加速粒子滤波器的重新采样计算。由于Python有许多加速方法,因此我想尝试它们全部。不幸的是,Numba版本非常慢。由于Numba应该导致加速,我认为这是我的错误。

我尝试了4种不同的版本:

  1. Numba
  2. Python
  3. Numpy
  4. Cython

每个版本的代码如下:

import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)

这将导致以下输出:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop

有没有想法为什么Numba代码运行这么慢?我原以为它至少和Numpy一样快。

注:如果有人有任何加速Numpy或Cython代码示例的想法,那也很好:)但我的主要问题是关于Numba的。


我认为这个更好的地方是http://codereview.stackexchange.com/。 - kylieCatt
1
尝试使用更大的列表? - Joran Beasley
2
@IanAuld:也许吧,但其他人从numba中获得了实质性的加速,我想这是因为我使用不当,而不仅仅是一个分析问题。在我看来,这似乎符合stackoverflow的预期用途。 - jiminy_crist
你有一块合适的GPU吗?可以兼容Numba吗?(我不确定要求是什么)顺便说一句,今天早上我刚看了这个视频http://www.youtube.com/watch?v=iYAG6I433gQ...它可能会有所启发(也可能没有)。 - Joran Beasley
1
作为一条注释,argmax 可以接受一个轴参数,因此您可以将 randslookup 广播到彼此,以制作一个 n x n 矩阵,用于 N^2 缩放算法。或者,您可以使用 searchsorted,它应该具有 Nlog(N) 缩放。 - Daniel
显示剩余2条评论
2个回答

25
问题在于 numba 无法感知 lookup 的类型。如果您在方法中输入print nb.typeof(lookup),您会发现 numba 将其视为对象处理,这会导致速度变慢。通常情况下,我会在本地字典中定义lookup的类型,但是我遇到了一个奇怪的错误。因此,我创建了一个小包装器,以便能够显式定义输入和输出类型。
@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

那么我的时间安排是:

print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)

print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop

如果你使用jit而不是autojit,速度甚至可以更快:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

对我来说,这将时间从15.3微秒降至12.5微秒,但autojit的表现仍然令人印象深刻。


是的,问题已经解决了!我尝试了一下在numba_cumsum函数上展开循环并进行jit编译,但要么运行速度更慢,要么无法编译。看起来这已经是最快的了。奇怪的是,现在numba版本的运行速度大约是cython代码的两倍。由于它们都是编译过的,我觉得这很奇怪。你有什么想法吗? - jiminy_crist
@jammycrisp - 我也尝试手动编写cumsum,但我发现它比调用numpy稍微慢一些。至于cython和numba之间的差异,可能与您使用的c编译器有关,而不是llvm。您使用的是哪个编译器?在您的setup.py中指定了任何优化标志吗? - JoshAdel
我正在使用GCC 4.6.3。我不知道你可以在setup.py中添加编译器标志,但是在弄清楚后,我使用-O3编译,但似乎没有改变任何东西。 - jiminy_crist

3

更快的 numpy 版本(与 numpy_resample 相比速度提高了10倍)

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]

谢谢。我想到了一种方法来做这件事,但在跳到Cython之前没有深入研究过它。对于n=100,使用此方法仅获得旧的numpy函数的2倍加速,但这很好知道。不过我仍然很好奇为什么我的numba代码不起作用。 - jiminy_crist

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接