I was able to reproduce your problem. When I run the following code:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha,beta):
    values=np.empty(alpha.shape[0])
    indices=np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot=np.sum(alpha[i,:,:]*beta,axis=(1,2))
        index=np.argmax(dot)
        values[i]=dot[index]
        indices[i]=index
    return values,indices

a, b, m, n = 6, 5, 4, 3
parallel_value(np.random.rand(a, m, n), np.random.rand(b, m, n))
I get the error message:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:

>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))

There are 16 candidate implementations:
  - Of which 16 did not match due to:
    Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 1d, C), int64, array(float64, 1d, C))':
    No match.

During: typing of setitem at <ipython-input-41-44518cf5219f> (11)

File "<ipython-input-41-44518cf5219f>", line 11:
def parallel_value(alpha,beta):
    <source elided>
    index=np.argmax(dot)
    values[i]=dot[index]
    ^
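For context, here is what the failing line computes, spelled out in plain NumPy outside of numba (a quick sketch using the same shapes as above):

```python
import numpy as np

a, b, m, n = 6, 5, 4, 3
alpha = np.random.rand(a, m, n)
beta = np.random.rand(b, m, n)

# alpha[i] has shape (m, n); multiplying by beta (b, m, n) broadcasts it
# across the first axis, giving one (m, n) product per slice of beta.
prod = alpha[0] * beta
assert prod.shape == (b, m, n)

# Summing over the last two axes yields one dot product per beta slice.
dot = np.sum(prod, axis=(1, 2))
assert dot.shape == (b,)
```

This runs fine in plain NumPy; it is the typing of this pattern under `@njit(parallel=True)` that fails.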
Judging by the description of this issue on its GitHub page, dot operations in numba appear to be problematic. When I rewrite the code with explicit loops, it seems to work fine:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value_numba(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.zeros(beta.shape[0])
        for j in range(beta.shape[0]):
            for k in range(beta.shape[1]):
                for l in range(beta.shape[2]):
                    dot[j] += alpha[i, k, l] * beta[j, k, l]
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

def parallel_value_nonumba(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in range(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

a, b, m, n = 6, 5, 4, 3
np.random.seed(42)
A = np.random.rand(a, m, n)
B = np.random.rand(b, m, n)
res_num = parallel_value_numba(A, B)
res_nonum = parallel_value_nonumba(A, B)
print(f'res_num = {res_num}')
print(f'res_nonum = {res_nonum}')
Output:
res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
res_nonum = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
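Rather than eyeballing the printed arrays, you can compare the two result tuples programmatically. A small sketch (the `results_match` helper is mine, not part of the original code; shown with stand-in tuples shaped like the outputs above):

```python
import numpy as np

def results_match(res_a, res_b):
    # Compare (values, indices) tuples: values up to float tolerance,
    # indices exactly.
    va, ia = res_a
    vb, ib = res_b
    return np.allclose(va, vb) and np.array_equal(ia, ib)

# Demo with stand-in tuples shaped like the outputs above.
r1 = (np.array([3.5277, 2.4994]), np.array([1., 3.]))
r2 = (np.array([3.5277, 2.4994]), np.array([1., 3.]))
assert results_match(r1, r2)
```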
As far as I can tell, the explicit loops don't seem to hurt performance. I can't compare against the identical code without loops, since that version won't compile under numba, but I suspect it wouldn't matter much.
%timeit res_num = parallel_value_numba(A, B)
%timeit res_nonum = parallel_value_nonumba(A, B)
Output:
The slowest run took 1472.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.92 µs per loop
10000 loops, best of 5: 76.9 µs per loop
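The `%timeit` warning about caching largely reflects the first call triggering JIT compilation; calling the function once before timing gives a cleaner number. A small helper along these lines (the `measure` function is hypothetical, not part of the original code) makes the warm-up explicit:

```python
import time
import numpy as np

def measure(fn, *args, repeats=1000):
    # Warm-up call: for numba-jitted functions this absorbs the
    # one-time compilation cost before we start the clock.
    fn(*args)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

# Works with any callable; shown here with a plain NumPy stand-in.
A = np.random.rand(6, 4, 3)
avg = measure(np.sum, A)
```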
Finally, you can do this more efficiently with vectorized NumPy code. It's almost as fast as numba with explicit loops, and you don't have to worry about the initial compilation delay. Here's how you might do it:
def parallel_value_np(alpha, beta):
    alpha = alpha.reshape(alpha.shape[0], 1, alpha.shape[1], alpha.shape[2])
    beta = beta.reshape(1, beta.shape[0], beta.shape[1], beta.shape[2])
    dot = np.sum(alpha * beta, axis=(2, 3))
    indices = np.argmax(dot, axis=1)
    values = dot[np.arange(len(indices)), indices]
    return values, indices

res_np = parallel_value_np(A, B)
print(f'res_num = {res_np}')
%timeit res_np = parallel_value_np(A, B)
Output:
res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1, 3, 1, 1, 1, 1]))
The slowest run took 5.46 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.1 µs per loop
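If you prefer, the same contraction can also be written with `np.einsum` instead of reshapes, which avoids spelling out the broadcast axes by hand. A sketch, equivalent to `parallel_value_np` above (not benchmarked here):

```python
import numpy as np

def parallel_value_einsum(alpha, beta):
    # dot[i, j] = sum over k, l of alpha[i, k, l] * beta[j, k, l]
    dot = np.einsum('ikl,jkl->ij', alpha, beta)
    indices = np.argmax(dot, axis=1)
    values = dot[np.arange(dot.shape[0]), indices]
    return values, indices

np.random.seed(42)
A = np.random.rand(6, 4, 3)
B = np.random.rand(5, 4, 3)
values, indices = parallel_value_einsum(A, B)
```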