Parallelizing a maximum over an nd-array with Numba


I'm trying to use Numba to parallelize a Python function that takes two numpy ndarrays, alpha and beta, as arguments. They have shapes (a,m,n) and (b,m,n) respectively, so they broadcast over the trailing dimensions. The function computes the matrix dot product (Frobenius product) of 2D slices of the arguments and, for each slice of alpha, finds the slice of beta that maximizes this product. The code is as follows:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        # Frobenius product of alpha[i] with every slice of beta
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices
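
Equivalently, for each i the function finds the index j that maximizes np.sum(alpha[i] * beta[j]), storing that maximum in values[i] and j in indices[i].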

This code runs fine without the njit decorator, but the Numba compiler raises the following error:

No implementation of function Function(<built-in function setitem>) found for signature:

>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))

The problem seems to be this line: values[i]=dot[index]. I don't understand why it would be a problem. What is the cause, and how can I fix it?

Also, is there any advantage to adding nogil=True to the arguments of @njit?


Looks like a bug, since it also works in Numba when executed sequentially. - Jérôme Richard
1 Answer


I managed to reproduce your problem. When I ran the following code:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha,beta):
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices


a, b, m, n = 6, 5, 4, 3
parallel_value(np.random.rand(a, m, n), np.random.rand(b, m, n))

I got this error message:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))
 
There are 16 candidate implementations:
      - Of which 16 did not match due to:
      Overload of function 'setitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, C), int64, array(float64, 1d, C))':
       No match.

During: typing of setitem at <ipython-input-41-44518cf5219f> (11)

File "<ipython-input-41-44518cf5219f>", line 11:
def parallel_value(alpha,beta):
    <source elided>
        index=np.argmax(dot)
        values[i]=dot[index]
        ^
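
As Jérôme Richard points out in the comments, the same function works in Numba when executed sequentially. As a minimal check of that (my sketch, not part of the original answer), dropping parallel=True and replacing prange with a plain range makes the function compile and run:

import numpy as np
from numba import njit

@njit  # sequential compilation: no parallel=True
def sequential_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in range(alpha.shape[0]):  # plain range instead of prange
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

This narrows the failure down to Numba's parallel transformation rather than the function logic itself.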

According to the description of this problem on Numba's GitHub page, this kind of operation can be problematic in numba. (The signature in the error message, setitem(array(float64, 1d, C), int64, array(float64, 1d, C)), shows that in parallel mode numba types dot[index] as a 1D array rather than a scalar, which is why the assignment values[i]=dot[index] fails to type.)

When I rewrote the code with explicit loops, it seems to work:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value_numba(alpha,beta):
    values  = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.zeros(beta.shape[0])
        for j in prange(beta.shape[0]):
            for k in prange(beta.shape[1]):
                for l in prange(beta.shape[2]):
                    dot[j] += alpha[i,k,l]*beta[j, k, l]
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices
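
# Note (added): Numba parallelizes only the outermost prange loop; any
# prange nested inside it (the j/k/l loops above) is executed serially,
# exactly as if it were a plain range loop.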

def parallel_value_nonumba(alpha,beta):
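    # Outside of a jitted function, numba's prange simply falls back to range,
    # so this pure-Python reference version runs serially as ordinary NumPy code.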
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices


a, b, m, n = 6, 5, 4, 3
np.random.seed(42)
A = np.random.rand(a, m, n)
B = np.random.rand(b, m, n)
res_num   = parallel_value_numba(A, B)
res_nonum = parallel_value_nonumba(A, B)

print(f'res_num = {res_num}')
print(f'res_nonum = {res_nonum}')

Output:

res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1., 3., 1., 1., 1., 1.]))
res_nonum = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1., 3., 1., 1., 1., 1.]))

As far as I can tell, the explicit loops don't seem to hurt performance. I can't compare against the same code without the loops, since numba won't compile that version, but I suspect it makes little difference.

%timeit res_num   = parallel_value_numba(A, B)
%timeit res_nonum = parallel_value_nonumba(A, B)

Output:

The slowest run took 1472.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.92 µs per loop
10000 loops, best of 5: 76.9 µs per loop
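
A caveat about these numbers (my addition, not from the original answer): the "slowest run took 1472.03 times longer" warning is typical for JIT-compiled functions, since the first call likely includes compilation and thread-pool startup. Calling the function once before timing is a common warm-up precaution:

parallel_value_numba(A, B)   # warm-up call: triggers JIT compilation outside the timed runs
%timeit parallel_value_numba(A, B)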

Finally, you can do this more efficiently with vectorized numpy code. It's almost as fast as numba with the explicit loops, and you don't have to worry about the initial compilation delay. Here's how you might do it:

def parallel_value_np(alpha,beta):
    alpha   = alpha.reshape(alpha.shape[0], 1, alpha.shape[1], alpha.shape[2])
    beta    = beta.reshape(1, beta.shape[0], beta.shape[1], beta.shape[2])
    dot     = np.sum(alpha*beta, axis=(2,3))
    indices = np.argmax(dot, axis = 1)
    values  = dot[np.arange(len(indices)), indices]
    return values,indices


res_np = parallel_value_np(A, B)
print(f'res_np = {res_np}')

%timeit res_np = parallel_value_np(A, B)

Output:

res_np = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1, 3, 1, 1, 1, 1]))
The slowest run took 5.46 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.1 µs per loop
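
As a final aside (my sketch, not part of the original answer, and the function name is hypothetical), np.einsum computes the same pairwise Frobenius products without materializing the full (a, b, m, n) intermediate array, which may help for larger inputs:

def parallel_value_einsum(alpha, beta):
    # dot[i, j] = sum over k, l of alpha[i, k, l] * beta[j, k, l]
    dot = np.einsum('ikl,jkl->ij', alpha, beta)
    indices = np.argmax(dot, axis=1)                 # best beta slice for each alpha slice
    values = dot[np.arange(dot.shape[0]), indices]   # the corresponding maxima
    return values, indices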

Interesting. I'll accept this answer once I've had time to look at it more closely and run some tests. - mikefallopian
Sure, no rush. We're not here to get our answers accepted; we're here to help each other as best we can. If my answer wasn't helpful enough, don't accept it; that tells people you're still looking for a better one. - yann ziselman
