Numba的nopython模式无法接受二维布尔索引。

9

我正在尝试使用 numba(目前使用的是版本 0.45.1)来加速代码,并遇到了一个布尔索引的问题。代码如下:

from numba import njit
import numpy as np

n_max = 1000

n_arr = np.hstack((np.arange(1,3),
                   np.arange(3,n_max, 3)
                   ))

@njit
def func(arr):
    idx =  np.arange(arr[-1]).reshape((-1,1)) < arr -2
    result = np.zeros(idx.shape)
    result[idx] = 10.1
    return result

new_arr = func(n_arr)

我一运行代码,就会收到以下信息。

TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)

请注意,最后一行的 (29) 对应于第 29 行,也就是 result[idx] = 10.1 这一行,这是我尝试为索引为 idxresult 赋值的代码行,idx 是一个二维布尔索引。
我想解释一下,在 @njit 中包含该语句 result[idx] = 10.1 是必须的。尽管我想在 @njit 中排除此语句,但我无法这样做,因为这行代码正好在我正在处理的代码中间。
如果我坚持要在 @njit 中包括赋值语句 result[idx] = 10.1,那么需要做出哪些改变才能使其工作?如果可能的话,请给出涉及二维布尔索引的 @njit 代码示例,以便运行。
谢谢
1个回答

7
Numba目前不支持使用2D数组进行高级索引。请参见:https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access。但是,您可以通过显式使用for循环重写函数而不依赖广播来获得等效的行为。
from numba import njit
import numpy as np

n_max = 1000

n_arr = np.hstack((np.arange(1,3),
                   np.arange(3,n_max, 3)
                   ))

def func(arr):
    idx =  np.arange(arr[-1]).reshape((-1,1)) < arr -2
    result = np.zeros(idx.shape)
    result[idx] = 10.1
    return result

@njit
def func2(arr):
    M = arr[-1]
    N = arr.shape[0]
    result = np.zeros((M, N))
    for i in range(M):
        for j in range(N):
            if i < arr[j] - 2:
                result[i, j] = 10.1

    return result

new_arr = func(n_arr)
new_arr2 = func2(n_arr)
print(np.allclose(new_arr, new_arr2))  # True

在我的电脑上,使用您提供的示例输入,func2 的速度大约比 func 快 3.5 倍。

2
有趣。看到在这种情况下,在 @njit 中广播比 for 循环慢,我想知道这是否适用于大多数大数据操作?我之所以问这个问题,是因为虽然在这个例子中我只是在 numpy 数组中进行广播,但在未来,我将在 tensorflow 设置中进行广播,所以知道这个问题的答案将非常有帮助。 - mathguy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接