Numba JIT带有参数的函数比纯Python慢。

3

我刚刚写了一个简单的基准测试,比较了Numba和Julia,并进行了一些讨论。链接

我想知道是否可以修复我的Numba代码,或者我尝试做的确实不受Numba支持。

我的想法是使用即时编译的数值积分法来评估这个函数。

g(p) = integrate exp(p*x) with respect to x

这是简单的正交函数:

@nb.njit   
def quad_trap(f,a,b,N):
    h = (b-a)/N
    integral = h * ( f(a) + f(b) ) / 2
    for k in range(N):
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk)
    return integral

我可以将一个JIT编译的函数传递给这个函数,就像这个函数:

@nb.njit(nb.float64(nb.float64))
def func(x):
    return math.exp(x) - 10

这比纯Python快10-20倍,相当不错。

现在,我想做的是传递一个关于x和参数p的函数,类似于:

def g(p): 
    @nb.njit(nb.float64(nb.float64))
    def integrand(x):
        return math.exp(p*x) - 10
    return quad_trap(integrand, -1, 1, 10000) 

做这个似乎会导致Numba崩溃,甚至比纯Python更慢。

我是在做错什么,还是这个功能确实不受Numba支持?(我已经检查了文档,但我不确定问题具体在哪里。)谢谢!

1个回答

4

TL;DR:目前Numba似乎还不支持此功能。

这比纯Python快10-20倍,效果很好。

Numba函数quad_trap将在首次调用时被编译。如果参数类型更改,则Numba将重新编译该函数。编译时间通常远非微不足道(几毫秒至几秒)。为避免这种情况,通常的解决方案是指定参数的类型。但是据我所知,在这里(至少没有记录)由于函数的原因,这是不可能的。也就是说,因为您肯定使用相同的函数对quad_trap函数进行基准测试,所以Numba不应重新编译该函数,因为提供的参数的类型未更改。

执行似乎会破坏Numba,即使与纯Python相比,Numba表现得非常慢。

在最近的Numba版本中,它可以正常工作而不发出警告,但这是因为需要对函数integrand进行重复编译,因为Numba不知道其代码是否已更改(或递归地调用了其中的函数/运算符)。在较旧的版本中,Numba可能会抱怨函数integrand读取从其父封闭函数读取的参数p。这称为闭包。

由于闭包需要从它们父函数的堆栈中读取变量,所以通常情况下编译器支持较差。一个普遍存在的问题是,闭包可能会逃离其父函数的作用域并在外部调用,导致未定义的行为(因为闭包将尝试读取完成函数的已过时堆栈)。

一种技巧是将@nb.njit装饰器从integrand移动到g,但是Numba拒绝编译g,因为它不支持可以逃离其父函数作用域的闭包(由于先前描述的问题)。请注意,在您的情况下,闭包没有逃出定义它的函数,但Numba无法证明这一点(因为quad_trap函数已经编译),而且当函数quad_trap被内联时,它也不幸无法证明这一点(尽管从理论上讲,它可以证明这是安全的)。实际上,文档中指出:

Numba现在支持内部函数,只要它们是非递归的并且仅在本地调用,但不作为参数传递或作为结果返回。还支持在内部函数中使用闭包变量(在外部范围定义的变量)。

我认为使用 @generated_jit 装饰器可能有助于解决这个问题,但是我没有成功地在您的特定情况下使其工作。它至少应该帮助在定义时(例如 integrand)编译 g,而不是在首次调用时进行编译。
一种解决方案是简单地不使用闭包:
@nb.njit
def quad_trap_p(f,a,b,N,p):
    h = (b-a)/N
    integral = h * ( f(a,p) + f(b,p) ) / 2
    for k in range(N):
        xk = (b-a) * k/N + a
        integral = integral + h*f(xk,p)
    return integral

@nb.njit(nb.float64(nb.float64, nb.float64))
def integrand(x, p):
    return math.exp(p*x) - 10

def g(p):
    return quad_trap_p(integrand, -1, 1, 10000, p)

谢谢Jérôme,这真的为我澄清了问题!所以Numba每次调用函数时都会重新编译它。哦,你指出的内部函数那一段现在非常清楚了...所以我想可以说,在Numba中这是可能的,但该功能(更好地处理闭包)尚未实现。 - Martín Maas

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接