我一直在尝试使用Numba进行函数的自动生成/即时编译。您可以在jit函数内部调用其他jit函数,因此如果您有一组特定的函数,那么硬编码所需的功能就很容易,如下所示:
但是这不适用于nopython模式。
from numba import jit
@jit(nopython=True)
def f1(x):
return (x - 2.0)**2
@jit(nopython=True)
def f2(x):
return (x - 5.0)**2
def hardcoded(x, c):
@jit(nopython=True)
def f(x):
return c[0] * f1(x) + c[1] * f2(x)
return f
lincomb = hardcoded(3, (0.5, 0.5))
print(lincomb(2))
Out: 4.5
然而,假设您事先不知道f1、f2是什么。我希望能够使用工厂生成这些函数,然后再生成它们的线性组合:
def f_factory(x0):
@jit(nopython=True)
def f(x):
return (x - x0)**2
return f
def linear_comb(funcs, coeffs, nopython=True):
@jit(nopython=nopython)
def lc(x):
total = 0.0
for f, c in zip(funcs, coeffs):
total += c * f(x)
return total
return lc
并在运行时调用。这在不使用nopython模式的情况下有效:
funcs = (f_factory(2.0), f_factory(5.0))
lc = linear_comb(funcs, (0.5, 0.5), nopython=False)
print(lc(2))
Out: 4.5
但是这不适用于nopython模式。
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
print(lc(2))
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'funcs': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-100-2d3fb6214044>", line 11
看起来 Numba 在 jit 函数的元组上存在问题。有没有办法让这种行为正常工作?
由于函数集和 c 可能会很大,因此我真的希望以 nopython 模式编译它。
x0
设为f
的参数并移除这个因子? - chrisb