在Matlab中仅计算矩阵乘积的对角线。

4
有没有一种在Matlab中高效计算三个(或更多)矩阵乘积对角线的方法?具体而言,我想要:
diag(A'*B*A)

当A和B都很大时,可能需要很长时间。如果只有两个矩阵:

diag(B*A)

那我可以快速地这样做:
sum(B.*A',2)

现在我要计算三个矩阵的对角线,方法如下:

C = B*A;
ans = sum(A'.*C',2);

这很有帮助,但第一步操作(C = B*A)仍然需要很长时间。整个过程必须多次重复,导致我的代码运行需要几周的时间。例如,B大约是15k x 15k,A大约是32k x 15k。而且没有稀疏性。


有那么大且不稀疏的矩阵...我认为你是没有解决办法了。 - rayryeng
1
你的意思是A大约是15k x 32k,对吧?;-) - matheburg
看起来你无法避免进行完整的乘法运算(在你的代码中是C),因为为了进行第二次乘法运算,需要使用C中的所有元素。 - Luis Mendo
是的,我的意思是A是15k x 32k。感谢您发现了这个错误。 - user3682146
1个回答

3

首先,欢迎!老实说,这似乎有点困难。稍微改变一下至少能略微提高速度:

N = 5000;
A = rand(N,N*2);
B = rand(N,N);

t = cputime;
diag(A'*B*A);
disp(['Elapsed cputime ' num2str(cputime-t)]);

t=cputime;
C = B*A;
sum(A'.*C',2);
disp(['Elapsed cputime ' num2str(cputime-t)]);

% slightly better...
t=cputime;
C = B*A;
sum(A.*C)';
disp(['Elapsed cputime ' num2str(cputime-t)]);

% slightly better than slightly better...
t=cputime;
sum(A.*(B*A))';
disp(['Elapsed cputime ' num2str(cputime-t)]);

结果:

Elapsed cputime 82.2593
Elapsed cputime 28.6106
Elapsed cputime 25.8338
Elapsed cputime 25.7714

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接