快速计算(1:N)'*(1:N)的方法

6
我正在寻找一种快速计算的方法。
(1:N)'*(1:N)

对于相当大的N,我感觉问题的对称性使得实际执行乘法和加法是浪费的。


2
你为什么在意呢?用朴素的方法,我在我的机器上计算不到1秒就耗尽了内存... 你的情况可能会有所不同。 - John
7
“为什么”这个问题非常重要。在理论上,三角形法则可以减少计算操作。在实践中,矩阵乘法代码实现效率极高。即使你编写了直接的C语言代码并将操作次数降到最小,也很可能仍然无法超过完整的矩阵乘法。为什么?因为当CPU操作数量很少时,内存带宽会占主导地位,所以优化的缓存感知访问模式是提高运行速度的关键。 - Peter
彼得所说的。但出于学术目的,我发布了两个解决方案,它们的速度减慢了近一个数量级......考虑到这点,它们也不差。 - chappjc
2
@Peter,听起来像是正在形成一个答案的样子... - StrongBad
1
@StrongBad:只是挑刺一下:按照这个矩阵乘积的数学定义,实际上并没有涉及到加法。 - knedlsepp
显示剩余2条评论
4个回答

14
为什么要这样做的问题真的很重要。
从理论上讲,其他答案中建议的三角形方法会节省您的操作。@jgmao的答案尤其有趣,可以减少乘法。
在实际意义上,当编写快速代码时,要最小化的度量标准不再是CPU操作次数。内存带宽在拥有如此少的CPU操作时占主导地位,因此调整缓存感知的访问模式是如何使其快速运行的关键。矩阵乘法代码被实现得非常高效,因为它是如此常见的操作,并且每个值得一试的BLAS数值库的实现都将使用优化的访问模式和SIMD计算。
即使您编写了直接的C代码并将操作次数降至理论最低水平,您也可能仍然无法击败完整的矩阵乘法。这归结于找到最接近您操作的数值原语。
总之,有一个BLAS操作比DGEMM(矩阵乘法)更接近。它称为DSYRK,级别-k更新,可以用于完全等于A'*A。我很久以前写的这个MEX函数在这里。我很久没有碰过它了,但是当我第一次写它时,它确实可以运行,并且确实比直接的A'*A运行得更快。
/* xtrx.c: calculates x'*x taking advantage of the symmetry.
Peter Boettcher <email removed>
Last modified: <Thu Jan 23 13:53:02 2003> */

#include "mex.h"

const double one = 1;
const double zero = 0;

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *x, *z;
  int i, j, mrows, ncols;

  if(nrhs!=1) mexErrMsgTxt("One input required.");

  x = mxGetPr(prhs[0]);
  mrows = mxGetM(prhs[0]);
  ncols = mxGetN(prhs[0]);

  plhs[0] = mxCreateDoubleMatrix(ncols,ncols, mxREAL);
  z = mxGetPr(plhs[0]);

  /* Call the FORTRAN BLAS routine for rank k update */
  dsyrk_("U", "T", &ncols, &mrows, &one, x, &mrows, &zero, z, &ncols);

  /* Result is in the upper triangle.  Copy it down the lower part */
  for(i=0; i<ncols; i++)
      for(j=i+1; j<ncols; j++)
          z[i*ncols + j] = z[j*ncols + i];
}

1
非常好的答案!感谢您提供这段代码。它非常具有说明性。 - chappjc

6
MATLAB的矩阵乘法通常非常快,但是有几种方法可以获取仅上三角矩阵。它们比朴素地计算 v'*v(或使用调用更适当的BLAS中对称秩k更新函数的MEX包装器)要慢得多。无论如何,这里有几个仅使用MATLAB的解决方案:
第一个使用线性索引
% test vector
N = 1e3;
v = 1:N;

% compute upper triangle of product
[ii, jj] = find(triu(ones(N)));
upperMask = false(N,N);
upperMask(ii + N*(jj-1)) = true;
Mu = zeros(N);
Mu(upperMask) = v(ii).*v(jj); % other lines always the same computation

% validate
M = v'*v;
isequal(triu(M),Mu)

下面这种方式不会比朴素方法更快,但是这里提供另一种使用 bsxfun 计算下三角的解决方案:

Ml = bsxfun(@(x,y) [zeros(y-1,1); x(y:end)*y],v',v);

对于上半部分三角形:

Mu = bsxfun(@(x,y) [x(1:y)*y; zeros(numel(x)-y,1)],v',v);
isequal(triu(M),Mu)

对于整个矩阵,使用cumsum的另一种解决方案(其中v=1:N),这种方法实际上速度接近。

M = cumsum(repmat(v,[N 1]));

也许这些可以成为更好东西的起点。

5

这比 (1:N).'*(1:N) 快3倍,如果返回一个int32的结果是可以接受的(如果数字足够小,使用int16代替int32甚至更快):

N = 1000;
aux = int32(1:N);
result = bsxfun(@times,aux.',aux);

基准测试:

>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, bsxfun(@times,aux.',aux); end, toc
Elapsed time is 0.734992 seconds.

>> N = 1000; aux = 1:N; tic, for count = 1:1e2, aux.'*aux; end, toc
Elapsed time is 2.281784 seconds.

请注意,aux.'*aux不能用于aux = int32(1:N)
正如@DanielE.Shub所指出的那样,如果需要将结果作为double矩阵输出,则需要进行最终转换,此时收益非常小。
>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, double(bsxfun(@times,aux.',aux)); end, toc
Elapsed time is 2.173059 seconds.

1
公平地说,我认为你需要将“result”转换为double类型。但是在我的机器上,使用int32仍然略快,而使用int16则快得多。 - StrongBad

3

鉴于输入的特殊有序结构,考虑N=4的情况。

(1:4)'*(1:4) = [1 2 3 4
                2 4 6 8
                3 6 9 12
                4 8 12 16]

您会发现第一行只是 (1:N),从第二行(j=2)开始,这一行的值是前一行(j=1)加上 (1:N)。因此,1. 您不需要进行多次乘法。相反,您可以通过 N*N 次加法来生成它。2. 由于输出是对称的,只需计算输出矩阵的一半。因此,总计算量为 (N-1)+(N-2)+...+1 = N^2 / 2 次加法。

根据2:2:8需要执行的操作数量,您可能不需要执行除内存连接之外的任何操作。 - StrongBad

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接