我使用Anaconda分发版的Python和Numba,编写了以下Python函数,该函数将一个稀疏矩阵A
(以CSR格式存储)乘以一个密集向量x
:
@jit
def csrMult( x, Adata, Aindices, Aindptr, Ashape ):
numRowsA = Ashape[0]
Ax = numpy.zeros( numRowsA )
for i in range( numRowsA ):
Ax_i = 0.0
for dataIdx in range( Aindptr[i], Aindptr[i+1] ):
j = Aindices[dataIdx]
Ax_i += Adata[dataIdx] * x[j]
Ax[i] = Ax_i
return Ax
这里的A
是一个大型的scipy
稀疏矩阵。
>>> A.shape
( 56469, 39279 )
# having ~ 142,258,302 nonzero entries (so about 6.4% )
>>> type( A[0,0] )
dtype( 'float32' )
而且 x
是一个 numpy
数组。下面是调用上述函数的代码片段:
x = numpy.random.randn( A.shape[1] )
Ax = A.dot( x )
AxCheck = csrMult( x, A.data, A.indices, A.indptr, A.shape )
注意使用修饰符
@jit
,它告诉Numba对函数csrMult()
进行即时编译。在我的实验中,
csrMult()
函数的运行速度大约是scipy
的.dot()
方法的两倍。这对于Numba来说是一个非常令人印象深刻的结果。然而,MATLAB执行这个矩阵-向量乘法的速度大约比
csrMult()
快6倍。我认为这是因为MATLAB在执行稀疏矩阵-向量乘法时使用了多线程。
问题:
如何在使用Numba时并行化外部的for
循环?Numba以前有一个
prange()
函数,用于简化并行化尴尬的并行for
循环。不幸的是,Numba现在没有prange()
函数[ 实际上,这是错误的,请参见下面的编辑]。那么,在Numba的prange()
函数消失之后,正确的并行化这个for
循环的方法是什么?当从Numba中删除
prange()
时,Numba的开发人员有什么替代方案?
编辑1:
我更新到了最新版本的Numba,即.35版本,prange()
又回来了!它没有包含在我使用的版本.33中。
这是个好消息,但不幸的是,当我尝试使用prange()
并行化我的for循环时,我收到了一个错误消息。这里是Numba文档中的一个并行for循环示例(请参见第1.9.2节“显式并行循环”),下面是我的新代码:
from numba import njit, prange
@njit( parallel=True )
def csrMult_numba( x, Adata, Aindices, Aindptr, Ashape):
numRowsA = Ashape[0]
Ax = np.zeros( numRowsA )
for i in prange( numRowsA ):
Ax_i = 0.0
for dataIdx in range( Aindptr[i],Aindptr[i+1] ):
j = Aindices[dataIdx]
Ax_i += Adata[dataIdx] * x[j]
Ax[i] = Ax_i
return Ax
当我使用上面给定的代码片段调用此函数时,我收到以下错误:
AttributeError:在 nopython (转换为 parfors) 'SetItem' 失败 对象没有属性 'get_targets'。
鉴于上述尝试使用 prange
崩溃,我的问题如下:
什么是正确的方式(使用 prange
或其他方法)来并行化这个 Python for
循环?
如下所示,在 C++ 中并行化类似的 for 循环并获得 20-omp-threads 上的 8x 加速非常容易。因为稀疏矩阵向量乘法是科学计算中的基本操作,因此必须有一种使用 Numba 的方法来处理它。
编辑2:
这里是我对csrMult()
的 C++ 版本。在 C++ 版本中并行化for()
循环可以使代码的测试速度快大约 8 倍。这使我认为,在使用 Numba 时 Python 版本应该能够实现类似的加速。
void csrMult(VectorXd& Ax, VectorXd& x, vector<double>& Adata, vector<int>& Aindices, vector<int>& Aindptr)
{
// This code assumes that the size of Ax is numRowsA.
#pragma omp parallel num_threads(20)
{
#pragma omp for schedule(dynamic,590)
for (int i = 0; i < Ax.size(); i++)
{
double Ax_i = 0.0;
for (int dataIdx = Aindptr[i]; dataIdx < Aindptr[i + 1]; dataIdx++)
{
Ax_i += Adata[dataIdx] * x[Aindices[dataIdx]];
}
Ax[i] = Ax_i;
}
}
}
@vectorize
或@guvectorize
(以生成 ufuncs)。甚至可能需要将内部循环拆分为另一个函数。 - f0xdx