1. 通过 scipy.integrate.quad
传递额外参数
quad
文档中写道:
如果用户需要提高积分性能,则 f
可以是一个参数签名为以下之一的 scipy.LowLevelCallable
函数:
double func(double x)
double func(double x, void *user_data)
double func(int n, double *xx)
double func(int n, double *xx, void *user_data)
user_data
是包含在 scipy.LowLevelCallable
中的数据。在带有 xx
的调用形式中,n
是包含于 args
参数并且 xx[0] == x
的数组 xx
的长度。
因此,要通过 quad
传递额外参数给 integrand
,最好使用参数签名为 double func(int n, double *xx)
的函数。
您可以编写装饰器将积分函数转换为 LowLevelCallable
函数,如下所示:
import numpy as np
import scipy.integrate as si
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable
def jit_integrand_function(integrand_function):
jitted_function = numba.jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
return jitted_function(xx[0], xx[1])
return LowLevelCallable(wrapped.ctypes)
@jit_integrand_function
def integrand(t, *args):
a = args[0]
return np.exp(-t/a) / t**2
def do_integrate(func, a):
"""
Integrate the given function from 1.0 to +inf with additional argument a.
"""
return si.quad(func, 1, np.inf, args=(a,))
print(do_integrate(integrand, 2.))
>>>(0.326643862324553, 1.936891932288535e-10)
如果您不想使用装饰器,可以手动创建LowLevelCallable
并将其传递给quad
函数。
2. 包装被积函数
我不确定以下方法是否符合您的要求,但您也可以包装integrand
函数以实现相同的结果:
import numpy as np
from numba import cfunc
import numba.types
def get_integrand(*args):
a = args[0]
def integrand(t):
return np.exp(-t/a) / t**2
return integrand
nb_integrand = cfunc(numba.float64(numba.float64))(get_integrand(2.))
import scipy.integrate as si
def do_integrate(func):
"""
Integrate the given function from 1.0 to +inf.
"""
return si.quad(func, 1, np.inf)
print(do_integrate(get_integrand(2)))
>>>(0.326643862324553, 1.936891932288535e-10)
print(do_integrate(nb_integrand.ctypes))
>>>(0.326643862324553, 1.936891932288535e-10)
3. 将 voidptr
转换为 Python 类型
我认为目前还不可能。从2016年的这次讨论来看,voidptr
只是用来将上下文传递给 C 回调函数。
void * 指针的情况是针对那些外部 C 代码不会尝试解引用指针,而只是将其作为回调之间保留状态的一种方式的 API。我认为目前没有特别重要的地方,但我想提出这个问题。
尝试以下操作:
numba.types.RawPointer('p').can_convert_to(
numba.typing.context.Context(), CPointer(numba.types.Any)))
>>>None
看起来也不太令人鼓舞!
def wrapped (n, XX)
中,n 是参数的数量,XX 是参数值的列表/数组。因此,您可以轻松地将上述内容适应于任意数量的参数。 - Jacques Gaudindouble func(double x, void *user_data)
签名来传递numpy数组(包括一些元数据,如形状)和其他内容。例如:https://dev59.com/ibbna4cB1Zd3GeqPjvLs#58561573 - max9111