如何在Matlab中计算两个矩阵的外积平方和减去一个共同的矩阵?

3
假设有三个大小为 n * n 的矩阵 X、Y、S。如何快速计算以下标量 b?
for i = 1:n
  b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
end

计算成本为O(n^3)。有一种快速计算两个矩阵外积的方法。具体来说,C矩阵很重要。
for i = 1:n
  C = C + X(i,:)' * Y(i,:);
end

可以使用不含for循环的C = A.'*B计算,时间复杂度只有O(n^2)。那是否存在一种更快的方式来计算b呢?


你是不是想说 X(i,:)'*Y(i,:)?否则就没有提到 Y - bla
是的,谢谢你指出错误。我已经纠正了它。 - messcode
2个回答

4

您可以使用:

X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));

鉴于您的例子

b=0;
for i = 1:n
   b = b  + sum(sum((X(i,:).' * Y(i,:) - S).^2));
end

我们可以先将求和操作移出循环:
b=0;
for i = 1:n
  b = b  + (X(i,:).' * Y(i,:) - S).^2;
end
b=sum(b(:))

我们知道可以把(a - b)^2写成a^2 - 2*a*b + b^2

b=0;
for i = 1:n
  b = b  + (X(i,:).' * Y(i,:)).^2 - 2.* (X(i,:).' * Y(i,:)) .*S + S.^2;
end
b=sum(b(:))

我们知道(a * b) ^ 2a^2 * b^2是一样的:

X2 = X.^2;
Y2 = Y.^2;
S2 = S.^2;
b=0;
for i = 1:n
  b = b  + (X2(i,:).' * Y2(i,:)) - 2.* (X(i,:).' * Y(i,:)) .*S + S2;
end
b=sum(b(:))

现在我们可以分别计算每个术语:
 b = sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));

下面是Octave测试的结果,比较了我的方法和@AndrasDeak提供的两种方法以及基于循环的原始解决方案,针对输入大小为500*500的情况:

===rahnema1 (B)===
Elapsed time is 0.0984299 seconds.

===Andras Deak (B2)===
Elapsed time is 7.86407 seconds.

===Andras Deak (B3)===
Elapsed time is 2.99158 seconds.

===Loop solution===
Elapsed time is 2.20357 seconds


n=500;
X= rand(n);
Y= rand(n);
S= rand(n);

disp('===rahnema1 (B)===')
tic
    X2 = X.^2;
    Y2 = Y.^2;
    S2 = S.^2;
    b=sum(sum(X2.' * Y2 - 2 * (X.' * Y ) .* S + n * S2));
toc
disp('===Andras Deak (B2)===')
tic
    b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));
toc
disp('===Andras Deak (B3)===')
tic
    b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));
toc
tic
    b=0;
    for i = 1:n
      b = b  + sum(sum((X(i,:)' * Y(i,:) - S).^2));
    end
toc

1
非常好!-- 我猜 sum(reshape(...))sum(sum(...)) 更昂贵?这是因为 reshape 函数的开销吗?-- 为了避免双重求和,在 Octave 中,您可以执行 sum((...)(:)),而在 MATLAB R2018b 中,您现在可以执行 sum(...,'all'),这两种方法都比双重求和更加优雅,并且应该更快。 - Cris Luengo
谢谢。我不确定 sum-reshapesum-sum 哪个更好。在这里,我使用 sum-sum 反映了原始代码,并且可能更易读。sum((...)(:)) 可能会使可读性复杂化,而 sum(...,'all') 使用了丑陋的字符串参数传递方式。比较它们,我认为在当前问题中速度上的差异可以忽略不计。也许我们需要一个更优雅的函数/操作符? - rahnema1
双重求和需要一个中间数组,因此原则上它会做更多的工作。但是没错,在这种情况下,我确定这不是可衡量的差异。我同意你的易读性评论。一个新的运算符会很棒。就像sumall=@(x)sum(x(:))这样的东西。但这意味着约20个新的运算符(meanall, stdall, maxall等等),你不能只做一个而不做其他所有的。 :) - Cris Luengo

3

您可能没有多余的时间复杂度,但是您可以利用向量化来消除循环并尽可能多地利用低级别代码和缓存。它是否真的更快取决于您的维度,所以您需要进行一些计时测试以确定是否值得这样做:

% dummy data
n = 3;
X = rand(n);
Y = rand(n);
S = rand(n);

% vectorize
b2 = sum(reshape((permute(reshape(X, [n, 1, n]).*Y, [3,2,1]) - S).^2, 1, []));

% check
b - b2 % close to machine epsilon i.e. zero

发生的情况是我们在其中一个数组中插入了一个新的单例维度,导致一个大小为[n, 1, n]的数组与一个[n, n]的数组重叠,后者隐含地与[n, n, 1]相同。重叠的第一个索引对应于循环中的i,其余两个索引对应于每个i的二元乘积的矩阵索引。然后,我们重新排列索引,将"i"索引放在最后,以便我们可以再次广播结果,并且大小(隐式)为[n, n, 1]。然后,我们拥有的是一个大小为[n, n, n]的矩阵,其中前两个索引是原始矩阵的矩阵索引,最后一个索引对应于i。然后,我们只需要取平方并对每个项求和(而不是两次求和,我将数组重塑为一行并进行了一次求和)。
上述内容的轻微变化是对S进行转置,而不是3D数组,这可能会更快(同样,您应该测试时间)。
b3 = sum(reshape((reshape(X, [n, 1, n]).*Y - reshape(S.', [1, n, n])).^2, 1, []));

在性能方面,reshape是免费的(它只是重新解释数据,而不进行复制),但是permute/transpose通常会导致数据被复制时的性能损失。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接