我在 Numba 上做错了什么?

3

我正在尝试学习使用Numba模块。到目前为止,由于与NumPy的一些问题接口,我还没有能够使任何东西工作。这是我正在运行的代码(来自Numba文档)和我收到的错误:

from numba import jit
import numpy as np

x = np.arange(100).reshape(10, 10)

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

print(go_fast(x))

    Traceback (most recent call last):
File "C:/Users/JoHn/Documents/Current Classes/MEEN575_Optimization/HW6/Optimal_controller/angle_wrapping.py", line 84, in <module>
print(go_fast(x))
TypeError: expected dtype object, got 'numpy.dtype[float64]'

从一些搜索中了解到,这是一个已知错误,最近与新版本的Numba有关,需要较新的NumPy版本或类似的内容。但就我所知,我安装了最新的NumPy版本1.20。请问我做错了什么?我想明确说明的是,我从来没有完全理解如何在Python中清晰地设置环境,因此很可能我只是漏掉了一些显而易见的东西。


2
文档中可以得知,成功的类型推断是在nopython模式下编译的先决条件。您应该指定函数签名 - alex
这是您的工作示例吗?这应该可以直接使用(您使用哪个Numba版本?)或者x是其他dtype对象数组吗? - max9111
我正在使用版本0.45.1,这个例子是从numba文档中逐行复制的,但对我来说无效。 - Tarnarmour
2个回答

5

更新到0.53.1可以解决问题。这个问题也在我使用0.47.x时出现了。看起来更像是numpy的问题。解决方法之一是安装numpy >=1.20.0和numba v>0.52。

关于此问题的更多信息,请参考: https://github.com/numba/numba/issues/6041

P.S:不确定您是否仍然遇到此错误,只是想更新一下,我曾经遇到过类似的问题。


如果您安装了Anaconda,可能会阻止您更新Numba,除非您使用conda update --all一起更新所有的Anaconda。 (如果这导致Python像我一样消失,那么您可以使用conda install --force python.app来解决问题。) - Max

0

我曾经遇到过完全相同的问题。我尝试仅更新numpy和numba到历史上适用于numba的版本(正如其他答案所提到的),但这对我没有起作用。

解决我的问题的方法是完全更新conda和所有相关软件包。我使用以下命令进行了更新:

conda update -n base -c defaults conda

一定要重新启动计算机,因为还有其他包也会更新。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接