加速计算的一种方法是使用numba
,这是一个针对Python的即时编译器。
@jit
修饰符
Numba提供@jit
修饰符来编译一些Python代码并输出优化的机器代码,可在多个CPU上并行运行。只需很少的努力就可以将积分函数JIT编译,代码经过优化后可以更快地运行。甚至不必担心类型问题,Numba会在幕后处理所有这些。
from scipy import integrate
from numba import jit
@jit
def circular_jit(x, y, a):
if x**2 + y**2 < a**2 / 4:
return 1
else:
return 0
a = 4
result = integrate.nquad(circular_jit, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
这确实运行得更快,我在我的机器上计时,得到如下结果:
Original circular function: 1.599048376083374
Jitted circular function: 0.8280022144317627
这相当于计算时间减少了约50%。
Scipy的LowLevelCallable
由于Python语言的本质,使用Python进行函数调用需要消耗大量时间。 这种开销有时会使Python代码相对于像C这样的编译语言而言变得比较慢。
为了缓解这种情况,Scipy提供了一个LowLevelCallable
类,可用于提供对低级编译回调函数的访问。 通过这种机制,可以绕过Python的函数调用开销,并进一步节省时间。
请注意,在nquad
的情况下,传递给LowerLevelCallable
的cfunc
签名必须是以下之一:
double func(int n, double *xx)
double func(int n, double *xx, void *user_data)
在这里,int
是参数数量,参数值在第二个参数中给出。user_data
用于需要上下文才能操作的回调函数。
因此,我们可以稍微改变Python的循环函数签名以使其兼容。
from scipy import integrate, LowLevelCallable
from numba import cfunc
from numba.types import intc, CPointer, float64
@cfunc(float64(intc, CPointer(float64)))
def circular_cfunc(n, args):
x, y, a = (args[0], args[1], args[2])
if x**2 + y**2 < a**2/4:
return 1
else:
return 0
circular_LLC = LowLevelCallable(circular_cfunc.ctypes)
a = 4
result = integrate.nquad(circular_LLC, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
使用这种方法,我得到了:
LowLevelCallable circular function: 0.07962369918823242
与原始版本相比,这是一个95%的缩减,与该功能的即时编译版本相比则为90%。
一个定制化的装饰器
为了使代码更整洁并保持被积函数的签名灵活性,可以创建一个定制化的装饰器函数。它将即时编译被积函数并将其封装到一个LowLevelCallable
对象中,然后可以与nquad
一起使用。
from scipy import integrate, LowLevelCallable
from numba import cfunc, jit
from numba.types import intc, CPointer, float64
def jit_integrand_function(integrand_function):
jitted_function = jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
return jitted_function(xx[0], xx[1], xx[2])
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def circular(x, y, a):
if x**2 + y**2 < a**2 / 4:
return 1
else:
return 0
a = 4
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=(a,))
任意数量的参数
如果参数数量未知,则可以使用Numba提供的方便的carray
函数将CPointer(float64)
转换为Numpy数组。
import numpy as np
from scipy import integrate, LowLevelCallable
from numba import cfunc, carray, jit
from numba.types import intc, CPointer, float64
def jit_integrand_function(integrand_function):
jitted_function = jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
ar = carray(xx, n)
return jitted_function(ar[0], ar[1], ar[2:])
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def circular(x, y, a):
if x**2 + y**2 < a[-1]**2 / 4:
return 1
else:
return 0
ar = np.array([1, 2, 3, 4])
a = ar[-1]
result = integrate.nquad(circular, [[-a/2, a/2],[-a/2, a/2]], args=ar)