我正在尝试学习使用Numba模块。到目前为止,由于与NumPy的一些问题接口,我还没有能够使任何东西工作。这是我正在运行的代码(来自Numba文档)和我收到的错误:
from numba import jit
import numpy as np
x = np.arange(100).reshape(10, 10)
@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
trace = 0.0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
print(go_fast(x))
Traceback (most recent call last):
File "C:/Users/JoHn/Documents/Current Classes/MEEN575_Optimization/HW6/Optimal_controller/angle_wrapping.py", line 84, in <module>
print(go_fast(x))
TypeError: expected dtype object, got 'numpy.dtype[float64]'
从一些搜索中了解到,这是一个已知错误,最近与新版本的Numba有关,需要较新的NumPy版本或类似的内容。但就我所知,我安装了最新的NumPy版本1.20。请问我做错了什么?我想明确说明的是,我从来没有完全理解如何在Python中清晰地设置环境,因此很可能我只是漏掉了一些显而易见的东西。