Matlab重复矩阵乘法-循环与内置性能比较

5
给定矩阵A,我需要乘以其他n个向量Bi(即i = 1...n)。A的大小可以为5000x5000,因此Bi可以为5000x1。
如果我按以下方式评估乘积:
for i=1:n
    product=A*Bi;
    % do something with product
end

结果比计算产品慢得多(数量级),例如:
%assume that S is a matrix that contains the vectors Bi as columns, i.e. S(:,i)=Bi, then:
results=A*S;   %stores all the products in matrix form
% do something with results

问题在于向量Bi的数量n可能太大而无法存储在内存中,例如n = 300000,因此我需要使用循环方法,在每次评估乘积时使用它,然后丢弃向量Bi。
为什么这种方法与直接乘法相比如此缓慢,有没有克服这个问题的方法?

3
这个主题的好文章是为什么MATLAB在矩阵乘法方面如此快? - Adriaan
严肃地说,MathWorks应该对此进行适当的基准测试,并将结果以大号霓虹绿字体打印在某个地方。这个问题已经被问了很多次,而且仍然在被问。显然,网络上的答案对一些人来说不够好,那么为什么MathWorks(具有更深入的了解源代码)不试着做一下呢?@xarz 问得没错。如果网络上的答案不能满足需求,那么显然没有一个足够好的答案可以回答这个问题。 - patrik
@patrik 也许你是对的,但我在stackoverflow上查找了一下,没有找到处理这个确切问题的主题。顺便说一句,如果你能在这里链接一些涉及到这个确切问题的参考资料,它们可能对未来的读者有用。谢谢。 - xanz
4个回答

4
您可以尝试分批循环处理,例如:
for i = 0:(n/k)-1
    product = A*S(:,(i*k+1):(i+1)*k)
end

调整k以找到最适合您速度和内存的平衡点。
MATLAB的循环较慢,因为它是一种解释性语言。所以它必须即时处理大量的东西。由于JIT编译器的存在,这些循环在今天得到了极大的改进,但与使用C语言编写和编译的内置函数相比仍然较慢。此外,它们使用真正先进的超快矩阵乘法算法,与通过循环实现的相对简单的算法相比,这也有助于加速您的计算过程。

1
在这种情况下,您还可以使用“parfor”并行处理。如果有足够的核心和复杂度(或计算规模),那么可能会显著提高速度。 - Adriaan
1
@Adriaan 可能值得添加为另一个答案。虽然如果我没记错的话,* 运算符已经并行化了,所以很难预测会有什么样的加速效果,而且它也无法解决内存限制问题。 - Dan
1
切片中存在一个错误,它从k开始,并且每个第k个元素都出现两次。 - Daniel
@Daniel 呀,我其实没有解决它,只是举了个例子。不过我会尝试纠正它的。 - Dan
我使用了这样的方法,并且认为它是最好的。我编写了一个方法,用于将一批50000个向量Bi相乘,然后存储结果以允许外部函数检索每个乘积。当缓冲区为空时,它会再次填充。基本上就是你建议的! - xanz

3
为简单起见,我的答案假定一个n×n的方阵A,但对于非方阵也是正确的。
您的循环方法使用矩阵向量乘法。朴素解法也是目前最好的解法,导致运行时间为O(n^2),重复n次。你最终获得的总运行时间为O(n^3)。
对于矩阵乘法,有更好的方法。目前已知的最佳算法只需要略少于O(n^2.4)的运行时间,使其在大量数字上快得多。
当使用矩阵乘法一次性乘以多个向量Bi时,您将获得更好的运行时间。这将无法实现纯矩阵乘法的性能,但使用更大的b切片可能是最快的内存有效解决方案。
以下是不同讨论方法的一些代码:
n=5000;
k=100;
A=rand(n,n);
S=rand(n,n);
workers=matlabpool('size');
%for a parfor solution, the batch size must be smaller because multiple batches are stred in memory at once
kparallel=k/workers;
disp('simple loop:');
tic;
for i = 1:n
    product = A*S(:,n);
end
toc
disp('batched loop:');
tic;
for i = 1:(n/k)
    product = A*S(:,(i-1)*k+1:(i)*k);
end
toc
disp('batched parfor loop:');
tic;
parfor i = 1:(n/kparallel)
    product = A*S(:,(i-1)*kparallel+1:(i)*kparallel);
end
toc
disp('matrix multiplication:');
tic;
A*S;
toc

谢谢,这是非常有趣的关于这些算法运行时间的评论,这解释了问题。 - xanz
1
"matlabpool('size')"在R2015a中已经无法使用,您需要使用"parpool"。 - Adriaan

1
除了@Dan的回答之外,如果您有足够的核心和足够大的操作使其盈利,您可以尝试并行处理(有关parfor的内存消耗的更多细节,请参见此答案):
parfor ii = 0:(n/k)-1
    product = A*S(:,(ii*k+1):(ii+1)*k)
end

我在 mtimes 文档中没有看到是否隐式支持多线程,不过我猜值得一试。

2
建议使用更大的批量大小,直到达到内存限制,而不是使用parfor。这样会更快。 - Daniel
@xarz,那么Daniel所说的确实是事实;矩阵乘法非常快,最好批量处理尽可能大的批次,而不是并行处理。 - Adriaan

0
为了对每个数组与矩阵进行乘法运算,只需将矩阵与一个矩阵相乘,该矩阵的列将是您想要的数组。
因此,如果您想要检查这个问题
如果
size(a)=3,3

那么

 a*b==horzcat(a*b(:,1),a*b(:,2),a*b(:,3)) 

是真的

这样可以节省循环的时间


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接