编辑:增加了另一种可能的函数。
编辑:我添加了使用 np.linalg.multi_dot 的结果。我原本期望它比其他方法更快,但实际上它却慢得多,我想这是因为它是为其他类型的用例而设计的。
我不确定还能不能做得比这更快。以下是几种不同的归约(连乘)实现,其中数据是一个三维数组(即一组堆叠在一起的矩阵):
from multiprocessing import Pool
from functools import reduce
import numpy as np
import numba as nb
def matmul_n_naive(data):
    """Multiply a sequence of matrices left to right with ``np.matmul``.

    Args:
        data: Iterable of matrices with compatible shapes, for example a
            3-D array whose first axis indexes the matrices.

    Returns:
        The product ``data[0] @ data[1] @ ... @ data[-1]``. When ``data``
        holds a single matrix, that matrix is returned as is (no copy).

    Raises:
        TypeError: If ``data`` is empty, because ``reduce`` is called
            without an initial value.
    """
    return reduce(np.matmul, data)
def matmul_n_binary(data, copy=True):
    """Multiply a stack of matrices by pairwise (tree-shaped) reduction.

    On every pass, matrices that are ``s`` slots apart are multiplied in a
    single batched ``np.matmul`` call and each product is stored back into
    the slot of its left operand; ``s`` doubles after each pass until the
    full product sits in slot 0. Operand order is preserved, so the result
    equals ``data[0] @ data[1] @ ... @ data[-1]`` up to rounding.

    Args:
        data: Array-like of shape ``(n, r, r)`` with ``n >= 1``.
        copy: If true (the default), work on a private copy of ``data``.
            If false, ``data`` is converted without copying when possible,
            in which case an input array is overwritten with intermediate
            products.

    Returns:
        A new array holding the product of all the matrices.

    Raises:
        ValueError: If ``data`` is empty or is not three-dimensional.
    """
    if len(data) < 1:
        raise ValueError("data must contain at least one matrix")
    # np.array(..., copy=False) raises on NumPy >= 2 whenever a copy cannot
    # be avoided (e.g. list input); np.asarray copies only when needed.
    data = np.array(data) if copy else np.asarray(data)
    if data.ndim != 3:
        raise ValueError("data must be a 3-D array of stacked matrices")
    n = len(data)
    s = 1
    while (n + s - 1) // s > 1:
        a = data[:n - s:2 * s]
        b = data[s:n:2 * s]
        np.matmul(a, b, out=a)
        s *= 2
    # Slot 0 now holds the product. For n == 1 the loop never runs and slot 0
    # is simply the only matrix. Copy so the result does not alias the buffer.
    return np.array(data[0])
def matmul_n_pool(data):
    """Multiply a sequence of matrices pairwise using a process pool.

    Each round multiplies neighbouring pairs ``(0, 1), (2, 3), ...`` in
    worker processes; an unpaired trailing matrix is carried over to the
    next round unchanged. Rounds repeat until a single matrix is left.

    Args:
        data: Sequence of matrices supporting ``len`` and slicing (a list
            or a 3-D array) that contains at least one matrix.

    Returns:
        The product of all matrices in their original order. When ``data``
        holds a single matrix, that matrix is returned as is (no copy).

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) < 1:
        raise ValueError("data must contain at least one matrix")
    lst = data
    # Nothing to multiply: skip the cost of starting worker processes.
    if len(lst) == 1:
        return lst[0]
    with Pool() as pool:
        while len(lst) > 1:
            lst_next = pool.starmap(np.matmul, zip(lst[::2], lst[1::2]))
            # Odd count: the last matrix has no partner in this round.
            if len(lst) % 2 != 0:
                lst_next.append(lst[-1])
            lst = lst_next
    return lst[0]
@nb.njit(parallel=False)
def matmul_n_numba_nopar(data):
    """Multiply the matrices in ``data`` one after another in compiled code.

    Args:
        data: 3-D array; ``data[i]`` is the i-th matrix.

    Returns:
        ``data[0] @ data[1] @ ... @ data[-1]``, accumulated from left to
        right starting from an identity matrix. For an empty ``data`` the
        identity matrix itself is returned.
    """
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    # The function is compiled with parallel=False, so a plain range states
    # the sequential intent directly instead of a parallel-range construct.
    for i in range(len(data)):
        res = res @ data[i]
    return res
@nb.njit(parallel=True)
def matmul_n_numba_par(data):
    """Multiply the matrices in ``data`` inside a ``prange`` loop.

    Args:
        data: 3-D array; ``data[k]`` is the k-th matrix.

    Returns:
        The accumulated product, started from an identity matrix built with
        the shape of one matrix and the dtype of ``data``.
    """
    n_mats = data.shape[0]
    n_rows = data.shape[1]
    n_cols = data.shape[2]
    acc = np.eye(n_rows, n_cols, dtype=data.dtype)
    # NOTE(review): every iteration reads the previous value of `acc`, and
    # matrix multiplication is not commutative -- confirm that Numba's
    # handling of this prange loop preserves left-to-right order.
    for k in nb.prange(n_mats):
        acc = acc @ data[k]
    return acc
def matmul_n_multidot(data):
    """Multiply a sequence of matrices with ``np.linalg.multi_dot``.

    Args:
        data: Sequence of matrices supporting ``len`` and indexing (a list
            or a 3-D array) that contains at least one matrix.

    Returns:
        A new array holding the product of all matrices in their original
        order.

    Raises:
        ValueError: If ``data`` is empty (raised by ``multi_dot``).
    """
    # multi_dot refuses fewer than two operands; the product of a single
    # matrix is the matrix itself, returned as a copy so that the result
    # never aliases the input.
    if len(data) == 1:
        return np.array(data[0])
    return np.linalg.multi_dot(data)
下面是一个简单的测试:
import numpy as np
# Agreement check: run every implementation on the same seeded random input
# and report whether each result matches the naive one.
np.random.seed(0)
a = np.random.rand(10, 100, 100) * 2 - 1

b1 = matmul_n_naive(a)
b2 = matmul_n_binary(a)
b3 = matmul_n_pool(a)
b4 = matmul_n_numba_nopar(a)
b5 = matmul_n_numba_par(a)
b6 = matmul_n_multidot(a)

# One True/False line per implementation, in the order b2..b6.
print(*(np.allclose(b1, candidate) for candidate in (b2, b3, b4, b5, b6)), sep="\n")
以下是一些基准测试结果。似乎没有一个始终领先的方案,但“naive”方案在各种情况下表现都不错,二分(binary)方案和 Numba 方案的表现时好时坏,进程池的效果并不理想,而 np.linalg.multi_dot 在方阵上似乎并没有什么优势。
import numpy as np
# Benchmarks, meant to be run in IPython (`%timeit` is an IPython magic).
# Every scenario reseeds NumPy's RNG and rebinds `a` to a fresh stack of
# random matrices with entries drawn uniformly from [-0.05, 0.05).

# Scenario 1: 10 matrices of size 1000x1000 (matmul_n_pool is not timed).
np.random.seed(0)
a = np.random.rand(10, 1000, 1000) * 0.1 - 0.05
%timeit matmul_n_naive(a)
%timeit matmul_n_binary(a)
%timeit matmul_n_numba_nopar(a)
%timeit matmul_n_numba_par(a)
%timeit matmul_n_multidot(a)
# Scenario 2: 200 matrices of size 100x100 (matmul_n_pool is not timed).
np.random.seed(0)
a = np.random.rand(200, 100, 100) * 0.1 - 0.05
%timeit matmul_n_naive(a)
%timeit matmul_n_binary(a)
%timeit matmul_n_numba_nopar(a)
%timeit matmul_n_numba_par(a)
%timeit matmul_n_multidot(a)
# Scenario 3: 300 matrices of size 10x10, all six implementations.
np.random.seed(0)
a = np.random.rand(300, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
%timeit matmul_n_binary(a)
%timeit matmul_n_pool(a)
%timeit matmul_n_numba_nopar(a)
%timeit matmul_n_numba_par(a)
%timeit matmul_n_multidot(a)
# Scenario 4: 1000 matrices of size 10x10, all six implementations.
np.random.seed(0)
a = np.random.rand(1000, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
%timeit matmul_n_binary(a)
%timeit matmul_n_pool(a)
%timeit matmul_n_numba_nopar(a)
%timeit matmul_n_numba_par(a)
%timeit matmul_n_multidot(a)
把 while 循环放在一个单独的 with Pool(processes=num_cores) as pool: 中(num_cores 取合适的值),并且删除循环内部的 with 语句,会有帮助吗? - alani
……pool.starmap(matmul, zip(ls[::2], ls[1::2])),并且我会更改它以处理大小不是 2 的幂的输入:使用 ls_next = pool.starmap(...),然后 if len(ls) % 2 != 0: ls_next.append(ls[-1]),最后 ls = ls_next。 - jdehesa