使用复数NumPy数组和本地数据类型时出现numba TypingError问题

5

我有一个处理复杂数据类型的函数,并使用 numba 进行快速处理。我使用 numpy 声明一个零数组,带有复杂数据类型,在函数中稍后填充。但是在运行时,numba 无法重载生成零数组的函数。为了重现这个错误,我提供了一个 MWE。

import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex)
    a[idx] = 10
    return a

my_func(4)

在初始化数组a时,出现了以下错误。
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)

No implementation of function Function(<built-in function zeros>) found for signature:
zeros(Tuple(Literal[int](10), Literal[int](5)), dtype=Function(<class 'complex'>))
There are 2 candidate implementations:

 Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 511.
    With argument(s): '(UniTuple(int64 x 2), dtype=Function(<class 'complex'>))':
   No match.

我猜想这可能与变量 a 的数据类型有关(我需要它是复数类型)。我该如何解决这个错误呢?

非常感谢任何帮助。


我想你可以重写这个函数,让my_func在一个参数a上进行原地操作。 - hilberts_drinking_problem
@hilberts_drinking_problem,这只是一个例子。实际函数涉及的内容更多,要对该函数进行原地替换将会很困难。 - learner
我理解你的观点,但在某种程度上,你正在尝试实现一些功能,这些功能应该在numba中实现。 - hilberts_drinking_problem
此外,似乎将 complex 替换为 np.complex64np.complex128 可以正常工作,而我尝试的其他选择则不行。 - hilberts_drinking_problem
据推测,这对应于一对32位浮点数和64位浮点数。 - hilberts_drinking_problem
nb.complex128(与 nb.c16 相同)和 nb.complex64(与 nb.c8 相同)也可以使用。 - aerobiomat
1个回答

1
你的问题与复数无关。如果你指定 a = np.zeros((10, 5), dtype=int),你将遇到同样的问题。
虽然 numpy 接受 python 原生数据类型 intfloatcomplex 并将它们视为 np.int32np.float64np.complex128,但是 numba 却不会自动这样做。
因此,每当你在 jitted 函数中指定数据类型时,你要使用 numpy 数据类型:
import numpy as np
from numba import njit

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=np.complex128)
    a[idx] = 10
    return a

my_func(4)

或者您可以直接导入 numba 数据类型:

import numpy as np
from numba import njit, complex128

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=complex128)
    a[idx] = 10
    return a

my_func(4)

或通过 types:
import numpy as np
from numba import njit, types

@njit
def my_func(idx):
    a = np.zeros((10, 5), dtype=types.complex128)
    a[idx] = 10
    return a

my_func(4)

据我所知,使用这些选项中的任何一个都没有什么区别。这里是numba文档的相关部分。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接