Numba在nopython模式下的函数线性组合

3
我一直在尝试使用Numba进行函数的自动生成/即时编译。您可以在jit函数内部调用其他jit函数,因此如果您有一组特定的函数,那么硬编码所需的功能就很容易,如下所示:
from numba import jit
@jit(nopython=True)
def f1(x):
    return (x - 2.0)**2

@jit(nopython=True)
def f2(x):
    return (x - 5.0)**2

def hardcoded(x, c):
    @jit(nopython=True)
    def f(x):
        return c[0] * f1(x) + c[1] * f2(x)
    return f

lincomb = hardcoded(3, (0.5, 0.5))
print(lincomb(2))

Out: 4.5

然而,假设您事先不知道f1、f2是什么。我希望能够使用工厂生成这些函数,然后再生成它们的线性组合:

def f_factory(x0):
    @jit(nopython=True)
    def f(x):
        return (x - x0)**2
    return f

def linear_comb(funcs, coeffs, nopython=True):
    @jit(nopython=nopython)
    def lc(x):
        total = 0.0
        for f, c in zip(funcs, coeffs):
            total += c * f(x)
        return total
    return lc

并在运行时调用。这在不使用nopython模式的情况下有效:

funcs = (f_factory(2.0), f_factory(5.0))
lc = linear_comb(funcs, (0.5, 0.5), nopython=False)
print(lc(2))

Out: 4.5

但是这不适用于nopython模式。
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
print(lc(2))

TypingError: Failed at nopython (nopython frontend)
Untyped global name 'funcs': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-100-2d3fb6214044>", line 11

看起来 Numba 在 jit 函数的元组上存在问题。有没有办法让这种行为正常工作?

由于函数集和 c 可能会很大,因此我真的希望以 nopython 模式编译它。


有没有什么理由不将 x0 设为 f 的参数并移除这个因子? - chrisb
在实际代码中,f 可以更加任意。如果 x0 是函数的参数,则线性组合需要知道每个 f 的参数,这些参数可能是不同的。例如,我想能够说,有 f(x, x0, alpha, beta) 和 f2(x, x0, bool_flag),并且具有封装闭包,使它们看起来像 f(x)。 - evamicur
明白了,我认为目前限制条件下,你可能没有什么运气,除非进行一些非常丑陋的代码生成。 - chrisb
我也有同样的担忧。或许有一种方法可以为JIT函数生成名称,然后通过名称调用它们来进行代码生成,这不是最糟糕的事情。 - evamicur
1个回答

0

可能有更好的方法来解决这个问题,但作为一种巧妙的解决方法,您确实可以使用一些“模板”来生成元组中每个函数的唯一名称和调用。

def linear_comb(funcs, coeffs, nopython=True):
    scope = {'coeffs': coeffs}

    stmts = [
    'def lc(x):',
    '    total = 0.0',
    ]
    for i, f in enumerate(funcs):
        # give each function a unique name
        scope[f'_f{i}'] = f
        # codegen for total line
        stmts.append(f'    total += coeffs[{i}] * _f{i}(x)')
    stmts.append('    return total')

    code = '\n'.join(stmts)
    exec(code, scope)
    lc = jit(nopython=nopython)(scope['lc'])
    return lc

lc = linear_comb(funcs, (0.5, 0.5), nopython=True)

lc(2)
Out[103]: 4.5

不错!这对我很有用。我曾经试图想出类似的东西,但没有想到可以这样利用exec。真棒。 - evamicur

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接