Numba jit与Scipy

11

我想通过使用 numba 的 jit 来加速我编写的程序。 然而,由于 scipy 的许多函数使用 try ... except ... 结构,jit 不能处理它们(这一点我说得对吗?)

我想到了一个相对简单的解决方案,即复制我需要的 scipy 源代码并删除其中的 try except 部分(我已经知道它不会遇到错误,所以 try 部分始终有效)

然而,我不喜欢这个解决方案,也不确定它是否有效。

我的代码结构如下:

import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
        for idx in some_list:
            integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    except:
        fit_param=(0,0,0)
        ...

现在会导致以下错误:

LoweringError: 在对象模式后端中失败

我认为这是由于 jit 无法处理 try except 所导致的(如果我只在 curve_fitintegrate.quad 部分使用 jit 并绕过自己的 try except 结构,它也不起作用)。

import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def integral(lower, upper):
    return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)

@jit
def fitting(x, y, pzero, max_fev)
    return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)


def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
        for idx in some_list:
            integrated = integral(lower, upper)
    except:
        fit_param=(0,0,0)
        ...

有没有一种方法可以在不手动删除scipy代码中的所有try except结构的情况下,将jit与scipy.integrate.quad和curve_fit一起使用?

这样做会加速代码吗?


2
不要试图对scipy函数进行“jit”编译,为什么不集中精力加速你自己的函数fitfunction呢?正是这个函数被quadcurve_fit反复调用。quad已经使用了在_quadpack模块中编译的代码。 - hpaulj
仅在拟合函数中使用jit不会在我的代码中产生任何加速效果。 - Katermickie
2个回答

13
Numba不是用于加速代码的通用库。有一类问题可以使用numba更快地解决(特别是如果您有关于数组、数值计算的循环),但其他所有内容要么不被支持,要么只是略微更快或甚至更慢。
SciPy已经是一个高性能库,所以在大多数情况下,我会预期numba表现更差(或者很少:稍微好一点)。您可以进行一些profiling以找出瓶颈是否真的在您jit的代码中,然后您可以获得一些改进。但我怀疑瓶颈将在SciPy的编译代码中,并且该编译代码可能已经被大量优化(因此几乎不可能找到一个“仅”可以与该代码竞争的实现)。
目前,numba不支持tryexcept,正如您所正确推测的那样。

2.6.1. 语言

2.6.1.1. 结构

Numba 力求支持尽可能多的 Python 语言特性,但是一些语言特性不能在 Numba 编译的函数中使用。目前不支持以下 Python 语言特性:

[...]

  • 异常处理(try .. excepttry .. finally

所以这里的答案是不支持


问题在于,通常我们希望在numba函数内使用scipy。例如,经常会有一个循环充满着数值计算,而numba可以加快速度。但是如果其中一个数值计算调用了一个scipy函数(这是一种常见情况),那么现在就不能使用numba了。 - Nick Alger
@NickAlger 是的,肯定不方便。我并不是说numba现在很棒。然而,这是一个限制,而且通常很容易解决。例如,只有少数情况下我无法将外部函数调用移动到numba函数之外。此外,如果使用外部函数调用,通常甚至不再有必要使用numba,因为与在外部函数中花费的时间相比,numba循环和数值计算性能提升根本微不足道。 - MSeifert
但这真的取决于您的具体情况。无法就此给出一般性回答。如果您需要快速计算和外部函数调用,可以尝试使用Cython。 - MSeifert

7
现在,tryexcept与numba一起使用。但是,numba和scipy仍然不兼容。是的,Scipy调用已编译的C和Fortran代码,但是以numba无法处理的方式进行调用。
幸运的是,有一些替代方案可以很好地与numba配合使用!下面我将使用NumbaQuadpackNumbaMinpack来执行类似于您示例代码的曲线拟合和积分。声明:我组装了这些包。下面,我还提供了一个等效的scipy实现。
Scipy实现的速度慢了约18倍,相比之下,这些Scipy替代方案(NumbaQuadpack和NumbaMinpack)要快得多。

使用Scipy替代方案(0.23毫秒)

from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
    u = nb.carray(u_,(2,))
    args = nb.carray(args_,(200,))
    A, B = u
    x = args[:100]
    y = args[100:]
    for i in range(100):
        fvec[i] = fitfunction(x[i], A, B) - y[i] 
optimize_ptr = fitfunction_optimize.address

@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
    A = data[0]
    B = data[1]
    return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address

@nb.njit
def fast_function():  
    try:
        neqs = 100
        u_init = np.array([2.0,.8],np.float64)
        args = np.append(x,y)
        fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
        if not success:
            raise Exception

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
            if not success:
                raise Exception
    except:
        print('doing something else')
        
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)

使用Scipy (4.4毫秒)

import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit

np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

def function():
    try:
        p0 = (2.0,.8)
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
    except:
        print('do something else')

function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)

不错!我会仔细看看。 - Katermickie

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接