新建进程或复制矩阵(如果进程被重用)的成本超过了矩阵乘法的成本。无论如何,
numpy.dot()
可以自行利用不同的 CPU 核心。
可以通过在不同的进程中计算结果的不同行来分配矩阵乘法,例如,给定输入矩阵
a
和
b
,则结果的
(i,j)
元素为:
out[i,j] = sum(a[i,:] * b[:,j])
所以第 i
行可以计算为:
import numpy as np
def dot_slice(a, b, out, i):
t = np.empty_like(a[i,:])
for j in xrange(b.shape[1]):
np.multiply(a[i,:], b[:,j], t).sum(axis=1, out=out[i,j])
numpy
数组接受切片作为索引,例如,a[1:3,:]
返回第2行和第3行。
a
、b
是只读的,因此它们可以被子进程继承(在Linux上利用写时复制),结果使用共享数组计算。在计算过程中只复制索引:
import ctypes
import multiprocessing as mp
def dot(a, b, nprocesses=mp.cpu_count()):
"""Perform matrix multiplication using multiple processes."""
if (a.shape[1] != b.shape[0]):
raise ValueError("wrong shape")
mp_arr = mp.RawArray(ctypes.c_double, a.shape[0]*b.shape[1])
np_args = mp_arr, (a.shape[0], b.shape[1]), a.dtype
pool = mp.Pool(nprocesses, initializer=init, initargs=(a, b)+np_args)
for i in pool.imap_unordered(mpdot_slice, slices(a.shape[0], nprocesses)):
print("done %s" % (i,))
pool.close()
pool.join()
return tonumpyarray(*np_args)
在哪里:
def mpdot_slice(i):
dot_slice(ga, gb, gout, i)
return i
def init(a, b, *np_args):
"""Called on each child process initialization."""
global ga, gb, gout
ga, gb = a, b
gout = tonumpyarray(*np_args)
def tonumpyarray(mp_arr, shape, dtype):
"""Convert shared multiprocessing array to numpy array.
no data copying
"""
return np.frombuffer(mp_arr, dtype=dtype).reshape(shape)
def slices(nitems, mslices):
"""Split nitems on mslices pieces.
>>> list(slices(10, 3))
[slice(0, 4, None), slice(4, 8, None), slice(8, 10, None)]
>>> list(slices(1, 3))
[slice(0, 1, None), slice(1, 1, None), slice(2, 1, None)]
"""
step = nitems // mslices + 1
for i in xrange(mslices):
yield slice(i*step, min(nitems, (i+1)*step))
测试它:
def test():
n = 100000
a = np.random.rand(50, n)
b = np.random.rand(n, 60)
assert np.allclose(np.dot(a,b), dot(a,b, nprocesses=2))
在Linux上,这个多进程版本的性能与
使用线程并在计算期间释放GIL(在C扩展中)的解决方案相同:
$ python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)'
100 loops, best of 3: 9.05 msec per loop
$ python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 88.8 msec per loop
$ python -mtimeit -s'from test_cydot import a,b; import mpdot' 'mpdot.dot(a,b)'
done slice(49, 50, None)
..[snip]..
done slice(35, 42, None)
10 loops, best of 3: 82.3 msec per loop
注意:测试已更改为在所有地方使用
np.float64
。