在Numpy中实现低内存使用的大矩阵乘法

3
我有一个复杂的矩阵乘法,有几十万行和列。在某个时刻,内存使用量增长到100%,然后计算机就会冻结,我不得不手动重启它。
我尝试过使用Numba(将代码写在带有装饰器的函数中)和Dask(将numpy数组转换为da.from_array(var, chunk)),但都没有成功。我对这些工具都不是专家。
我已经阅读了很多类似问题的解答,但没有找到适合我的问题的好解决方案。
一个最小可复现的示例可能是:
m = 100000
n = 100000
a1 = np.random.rand(m)
a2 = np.random.rand(n)
c = np.random.rand(m)+1j*np.random.rand(m)
b = np.random.rand(n)+1j*np.random.rand(n)
A = np.exp(1j*np.outer(a1,a2))
d = c*np.dot(A,b)

在内存使用方面,什么是解决它的最佳选项?(不一定是最快的)

1
np.einsum在内存使用上更高效。 - MSS
1
@MSS,比dot更高效?你确定吗?但是A`呢?那会生成多个(m,n)缓冲区。 - hpaulj
1
你有空间来存储一个(n.m)的复数数组吗? - hpaulj
1
m维度分成批次并对其进行迭代可能会起作用。 - hpaulj
1
m维度分成批次并对其进行迭代可能会起作用。 - hpaulj
显示剩余11条评论
2个回答

4

主要问题

主要问题是1j*np.outer(a1,a2)占用了100_000 * 100_000 * (8 * 2) = 149 GiB的空间。此外,np.exp需要读取这个矩阵并生成一个相同大小的矩阵,所以你至少需要 ~300 GiB 的内存来完成这个操作。这是非常庞大且低效的。

你应该尽量避免创建矩阵A(包括类似的临时矩阵)


快速、内存高效的解决方案

Numba可以在这种情况下提供帮助:您可以即时计算数组d,避免使用大量临时矩阵。以下是一个优化过的Numba代码示例:

import numba as nb
import numpy as np

@nb.njit('(float64[::1], float64[::1], complex128[::1], complex128[::1])', parallel=True)
def compute(a1, a2, b, c):
    m, n = a1.size, a2.size
    assert n == m  # seems already mantatory in the initial code
    tmpDot = np.zeros(n, dtype=np.complex128)
    for i in nb.prange(n):
        for j in range(n):
            tmpDot[i] += np.exp(1j * (a2[j] * a1[i])) * b[j]
    return c * tmpDot

m = 100000
n = 100000
a1 = np.random.rand(m)
a2 = np.random.rand(n)
c = np.random.rand(m)+1j*np.random.rand(m)
b = np.random.rand(n)+1j*np.random.rand(n)
d = compute(a1, a2, b, c)

这段代码的内存占用相比初始代码非常小,仅为几兆字节。因此,它所需的内存量约为原来的100,000倍!此外,我还预计它的运行速度会显著加快(因为它采用了多线程,并且更好地利用了CPU缓存和RAM)。在我的机器上只需17.1秒(而我甚至无法运行初始代码)!

1
这个运行得非常好! - user536696
1
这个效果非常棒! - user536696
1
这个效果非常棒! - undefined

1

估计内存使用情况:

m float64  # a1 = np.random.rand(m)
n float64  # a2 = np.random.rand(n)
m complex128   # c = np.random.rand(m)+1j*np.random.rand(m)
n complex   #b = np.random.rand(n)+1j*np.random.rand(n)

这些 b 和 c 行将会有几个复杂的临时数组,但最终只会留下一个。

(m,n) complex   # A = np.exp(1j*np.outer(a1,a2))

outer 创建了一个 (m,n) 复数;1j* 创建了另一个复数;exp 创建了另一个复数

可以试一下

A = np.zeros((m,n), dtype=complex)
np.outer(1j*a1,a2, out=A)
np.exp(A, out=A)

最后:
m complex    # d = c*np.dot(A,b)

dot函数将(m,n)与(n,)相乘得到(m,)的结果。可以通过一些方法使其更加紧凑。

np.multiply(np.dot(A,b), c, out=c))

dot接受一个out参数,但我没有一个空闲的(m,)复数

使用out可能会节省一些内存,消除一些(m,n)复杂的临时缓冲区。甚至可能节省一点时间。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接