一系列矩阵的快速乘法

4

最快的运行方式是什么:

    reduce(lambda x,y : x@y, ls)

在Python中?

针对矩阵列表ls。我没有Nvidia GPU,但我有很多CPU核心可以使用。我认为我可以并行处理过程(将其拆分为log次迭代),但对于小型(1000x1000)矩阵,这实际上是最差的。这是我尝试的代码:

from multiprocessing import Pool
import numpy as np
from itertools import zip_longest

def matmul(x):
    if x[1] is None:
        return x[0]
    return x[1]@x[0]

def fast_mul(ls):
    while True:
        
        n = len(ls)
        if n == 0:
            raise Exception("Splitting Error")
        if n == 1:
            return ls[0]
        if n == 2:
            return ls[1]@ls[0]

        with Pool(processes=(n//2+1)) as pool:
            ls = pool.map(matmul, list(zip_longest(*[iter(ls)]*2)))
    


1
很奇怪 - 你似乎已经在做一个明智的事情了(利用数组乘法是可结合的这一事实)。如果你把整个 while 循环放在一个单独的 with Pool(process=num_cores) as pool: 中(对于适当的 num_cores 值),并且删除循环内部的 with 语句,会有帮助吗? - alani
嘿!看起来它加快了一点,但仍然比朴素的方法慢。 - Yotam Vaknin
@YotamVaknin 我不会创建与所需乘法数量相同的进程,只需使用系统并发性(默认值),池的重点是在有许多任务需要处理时重复使用进程。 - jdehesa
除此之外,您可以将代码简化为 pool.starmap(matmul, zip(ls[::2], ls[1::2])),并且我会更改它以处理大小不是2的幂的输入,使用 ls_next = pool.starmap(...) 然后 if len(ls) % 2 != 0: ls_next.append(ls[-1]),最后 ls = ls_next - jdehesa
此外,请记住,发送所有数据需要时间。使用NumPy内置的GIL版本可能会更快。 - Mad Physicist
显示剩余3条评论
2个回答

2

编辑:增加了另一种可能的函数。

编辑:我使用np.linalg.multi_dot添加了结果,期望它比其他方法更快,但实际上它却要慢得多。我想这是因为它是设计用于其他类型的用例。


我不确定你能否得到比这更快的速度。以下是针对数据为3D矩阵数组的缩减的几种不同实现:

from multiprocessing import Pool
from functools import reduce
import numpy as np
import numba as nb

def matmul_n_naive(data):
    return reduce(np.matmul, data)

# If you don't care about modifying data pass copy=False
def matmul_n_binary(data, copy=True):
    if len(data) < 1:
        raise ValueError
    data = np.array(data, copy=copy)
    n, r, c = data.shape
    dt = data.dtype
    s = 1
    while (n + s - 1) // s > 1:
        a = data[:n - s:2 * s]
        b = data[s:n:2 * s]
        np.matmul(a, b, out=a)
        s *= 2
    return np.array(a[0])

def matmul_n_pool(data):
    if len(data) < 1:
        raise ValueError
    lst = data
    with Pool() as pool:
        while len(lst) > 1:
            lst_next = pool.starmap(np.matmul, zip(lst[::2], lst[1::2]))
            if len(lst) % 2 != 0:
                lst_next.append(lst[-1])
            lst = lst_next
    return lst[0]

@nb.njit(parallel=False)
def matmul_n_numba_nopar(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):
        res = res @ data[i]
    return res

@nb.njit(parallel=True)
def matmul_n_numba_par(data):
    res = np.eye(data.shape[1], data.shape[2], dtype=data.dtype)
    for i in nb.prange(len(data)):  # Numba knows how to do parallel reductions correctly
        res = res @ data[i]
    return res

def matmul_n_multidot(data):
    return np.linalg.multi_dot(data)

同时进行测试:

# Test
import numpy as np

np.random.seed(0)
a = np.random.rand(10, 100, 100) * 2 - 1
b1 = matmul_n_naive(a)
b2 = matmul_n_binary(a)
b3 = matmul_n_pool(a)
b4 = matmul_n_numba_nopar(a)
b5 = matmul_n_numba_par(a)
b6 = matmul_n_multidot(a)
print(np.allclose(b1, b2))
# True
print(np.allclose(b1, b3))
# True
print(np.allclose(b1, b4))
# True
print(np.allclose(b1, b5))
# True
print(np.allclose(b1, b6))
# True

以下是一些基准测试结果,似乎没有一个一致的胜者,但“naive”解决方案在各个方面都表现不错,二进制和Numba有所不同,进程池并不是很好,np.linalg.multi_dot 在方阵中似乎并不是非常优越。

import numpy as np

# 10 matrices 1000x1000
np.random.seed(0)
a = np.random.rand(10, 1000, 1000) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 121 ms ± 6.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_binary(a)
# 165 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_nopar(a)
# 108 ms ± 510 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit matmul_n_numba_par(a)
# 244 ms ± 7.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_multidot(a)
# 132 ms ± 2.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# 200 matrices 100x100
np.random.seed(0)
a = np.random.rand(200, 100, 100) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 4.4 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_binary(a)
# 13.4 ms ± 299 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_nopar(a)
# 9.51 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_numba_par(a)
# 4.93 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit matmul_n_multidot(a)
# 1.14 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 300 matrices 10x10
np.random.seed(0)
a = np.random.rand(300, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 526 µs ± 953 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 152 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_pool(a)
# 610 ms ± 5.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 239 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 175 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit matmul_n_multidot(a)
# 3.68 s ± 87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# 1000 matrices 10x10
np.random.seed(0)
a = np.random.rand(1000, 10, 10) * 0.1 - 0.05
%timeit matmul_n_naive(a)
# 1.56 ms ± 4.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_binary(a)
# 392 µs ± 790 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_pool(a)
# 727 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit matmul_n_numba_nopar(a)
# 589 µs ± 356 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_numba_par(a)
# 451 µs ± 1.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit matmul_n_multidot(a)
# Never finished...

你能测试一下我的25行或更少的hack吗?这不是一个严肃的解决方案,但我只是想发布它。 - Mad Physicist
实际上,我也发现了 multi_dot。我怀疑它会是最快的。 - Mad Physicist
@MadPhysicist 我添加了这个基准测试,希望能找到它的性能表现,但在大多数情况下似乎要慢得多。我猜它一定是针对某种特定的使用情况进行了优化。 - jdehesa
1
我进行了一些简单的计时。值得看一下非正方形的情况。我认为你会看到一个相当大的差异。 - Mad Physicist
1
multi_dot(以及子函数)的代码位于np.linalg.linalg.py中。除了对np.dot的调用之外,它全部都是Python,专注于链接dot调用的“最佳”顺序。 - hpaulj

2
有一个可实现此功能的函数:np.linalg.multi_dot,据说它是针对最佳计算顺序进行了优化的。
np.linalg.multi_dot(ls)

实际上,文档中说的与您最初的措辞非常接近:

Think of multi_dot as:

def multi_dot(arrays): return functools.reduce(np.dot, arrays)

您也可以尝试使用np.einsum,它允许您对多达25个矩阵进行乘法运算:

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)

时间

简单情况:

ls = np.random.rand(100, 1000, 1000) - 0.5

%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

更复杂的情况:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

请注意,由multi_dot完成的前期工作在简单情况下具有负面效果(更令人惊讶的是,lambda比原始运算符更快),但在不太直观的情况下可以节省75%的时间。
因此,为了完整起见,这里提供一个不太平凡的情况:
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

实际上,对于大多数一般情况来说,您最初的reduce调用实际上已经足够好了。我的建议是使用operator.matmul代替lambda。


为什么只限于乘以25个矩阵?这与字母表中的字母数量有关吗? - jkr
@jakub。已修复。感谢你的发现 :) - Mad Physicist
@jdehesa。看起来multi_dot在大多数情况下确实要慢得多。只是恰好在那种情况下效果更好。 - Mad Physicist
我对以前的问题的记忆是,只有当数组的形状不同时,multi_dot 才会产生差异。在开始时需要一些时间来确定哪些 dot 组合可以最快地减小问题规模。如果所有数组的大小都相同,则没有帮助,并且额外的前期工作会更慢。 - hpaulj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接