我正在尝试使用 numba
(目前使用的是版本 0.45.1
)来加速代码,并遇到了一个布尔索引的问题。代码如下:
from numba import njit
import numpy as np
n_max = 1000
n_arr = np.hstack((np.arange(1,3),
np.arange(3,n_max, 3)
))
@njit
def func(arr):
idx = np.arange(arr[-1]).reshape((-1,1)) < arr -2
result = np.zeros(idx.shape)
result[idx] = 10.1
return result
new_arr = func(n_arr)
我一运行代码,就会收到以下信息。
TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
In definition 2:
All templates rejected with literals.
In definition 3:
All templates rejected without literals.
In definition 4:
All templates rejected with literals.
In definition 5:
All templates rejected without literals.
In definition 6:
All templates rejected with literals.
In definition 7:
All templates rejected without literals.
In definition 8:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)
请注意,最后一行的
(29)
对应于第 29 行,也就是 result[idx] = 10.1
这一行,这是我尝试为索引为 idx
的 result
赋值的代码行,idx
是一个二维布尔索引。
我想解释一下,在
@njit
中包含该语句 result[idx] = 10.1
是必须的。尽管我想在 @njit
中排除此语句,但我无法这样做,因为这行代码正好在我正在处理的代码中间。如果我坚持要在
@njit
中包括赋值语句 result[idx] = 10.1
,那么需要做出哪些改变才能使其工作?如果可能的话,请给出涉及二维布尔索引的 @njit
代码示例,以便运行。谢谢
@njit
中广播比 for 循环慢,我想知道这是否适用于大多数大数据操作?我之所以问这个问题,是因为虽然在这个例子中我只是在numpy
数组中进行广播,但在未来,我将在tensorflow
设置中进行广播,所以知道这个问题的答案将非常有帮助。 - mathguy