有没有一个内置的Matlab函数可以计算二次型(x'*A*x)?

3

问题很简单:给定一个N x N对称矩阵A和一个N向量x,是否有内置的Matlab函数来计算x'*A*x?也就是说,是否有一个函数quadraticform使得y = quadraticform(A, x),而不是y = x'*A*x

显然,我可以直接做y = x'*A*x,但我需要更好的性能,看起来应该有一种方法利用以下两点:

  1. A是对称的
  2. 左右乘数是相同的向量

如果没有单个内置函数,是否有比x'*A*x更快的方法?或者,Matlab解析器是否足够聪明以优化x'*A*x?如果是这样,请指出文档中验证此事实的位置。


https://dev59.com/xkzSa4cB1Zd3GeqPq_TJ - kol
谢谢。虽然这的确快了一点,但我还是想让它保持开放状态,看看是否有其他建议(从技术上讲,这不是同一个问题,但达到了相同的目的)。sum(x.(Ax))没有利用对称性或重复性...这是一个非常普遍的计算,因此似乎会有内置函数... - dantswain
2个回答

7
我没有找到这样的内置函数,我有一个想法。 y=x'*A*x 可以写成 A(i,j)*x(i)*x(j)n^2 项和,其中 ij 运行从 1 到 n(其中 A 是一个 nxn 矩阵)。A 是对称的:对于所有的 ij,有 A(i,j) = A(j,i)。由于对称性,在和中每个项都出现了两次,除了那些 i 等于 j 的项。因此我们有 n*(n+1)/2 种不同的项。每个项有两个浮点乘法,所以天真的方法需要总共 n*(n+1) 个乘法。很容易看出,天真的计算 x'*A*x,即计算 z=A*x,然后 y=x'*z,也需要 n*(n+1) 个乘法。然而,有一种更快的方法来求和我们的 n*(n+1)/2 不同的项:对于每个 i,我们可以分解出 x(i),这意味着只需要 n*(n-1)/2+3*n 个乘法就足够了。但这并不能真正帮助:计算 y=x'*A*x 的运行时间仍然是 O(n^2)
因此,我认为二次形式的计算不能比 O(n^2) 更快,由于这也可以通过公式 y=x'*A*x 实现,所以特殊的“quadraticform”函数没有真正的优势。 === 更新 === 我已经用 C 写了一个名为“quadraticform”的函数,作为 Matlab 扩展。
// y = quadraticform(A, x)
#include "mex.h" 

/* Input Arguments */
#define A_in prhs[0]
#define x_in prhs[1]

/* Output Arguments */
#define y_out plhs[0] 

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  mwSize mA, nA, n, mx, nx;
  double *A, *x;
  double z, y;
  int i, j, k;

  if (nrhs != 2) { 
      mexErrMsgTxt("Two input arguments required."); 
  } else if (nlhs > 1) {
      mexErrMsgTxt("Too many output arguments."); 
  }

  mA = mxGetM(A_in);
  nA = mxGetN(A_in);
  if (mA != nA)
    mexErrMsgTxt("The first input argument must be a quadratic matrix.");
  n = mA;

  mx = mxGetM(x_in);
  nx = mxGetN(x_in);
  if (mx != n || nx != 1)
    mexErrMsgTxt("The second input argument must be a column vector of proper size.");

  A = mxGetPr(A_in);
  x = mxGetPr(x_in);
  y = 0.0;
  k = 0;
  for (i = 0; i < n; ++i)
  {
    z = 0.0;
    for (j = 0; j < i; ++j)
      z += A[k + j] * x[j];
    z *= x[i];
    y += A[k + i] * x[i] * x[i] + z + z;
    k += n;
  }

  y_out = mxCreateDoubleScalar(y);
}

我将这段代码保存为“quadraticform.c”,并使用Matlab进行编译:

mex -O quadraticform.c

我写了一个简单的性能测试,以比较这个函数与x'Ax的表现:

clear all; close all; clc;

sizes = int32(logspace(2, 3, 25));
nsizes = length(sizes);
etimes = zeros(nsizes, 2); % Matlab vs. C
nrepeats = 100;
h = waitbar(0, 'Please wait...');
for i = 1 : nrepeats
  for j = 1 : nsizes
    n = sizes(j);
    A = randn(n); 
    A = (A + A') / 2;
    x = randn(n, 1);
    if randn > 0
      start = tic;
      y1 = x' * A * x;
      etimes(j, 1) = etimes(j, 1) + toc(start);
      start = tic;
      y2 = quadraticform(A, x);
      etimes(j, 2) = etimes(j, 2) + toc(start);      
    else
      start = tic;
      y2 = quadraticform(A, x);
      etimes(j, 2) = etimes(j, 2) + toc(start);      
      start = tic;
      y1 = x' * A * x;
      etimes(j, 1) = etimes(j, 1) + toc(start);
    end;
    if abs((y1 - y2) / y2) > 1e-10
      error('"x'' * A * x" is not equal to "quadraticform(A, x)"');
    end;
    waitbar(((i - 1) * nsizes + j) / (nrepeats * nsizes), h);
  end;
end;
close(h);
clear A x y;
etimes = etimes / nrepeats;

n = double(sizes);
n2 = n .^ 2.0;
i = nsizes - 2 : nsizes;
n2_1 = mean(etimes(i, 1)) * n2 / mean(n2(i));
n2_2 = mean(etimes(i, 2)) * n2 / mean(n2(i));

figure;
loglog(n, etimes(:, 1), 'r.-', 'LineSmoothing', 'on');
hold on;
loglog(n, etimes(:, 2), 'g.-', 'LineSmoothing', 'on');
loglog(n, n2_1, 'k-', 'LineSmoothing', 'on');
loglog(n, n2_2, 'k-', 'LineSmoothing', 'on');
axis([n(1) n(end) 1e-4 1e-2]);
xlabel('Matrix size, n');
ylabel('Running time (a.u.)');
legend('x'' * A * x', 'quadraticform(A, x)', 'O(n^2)', 'Location', 'NorthWest');

W = 16 / 2.54; H = 12 / 2.54; dpi = 100;
set(gcf, 'PaperPosition', [0, 0, W, H]);
set(gcf, 'PaperSize', [W, H]);
print(gcf, sprintf('-r%d',dpi), '-dpng', 'quadraticformtest.png');

结果非常有趣。无论是 x'*A*x 还是 quadraticform(A,x) 的运行时间都收敛于 O(n^2),但前者的因子更小:

quadraticformtest.png


如果您有Matlab编译器,那么您可以在C中编写您的“quadraticform”函数,并从Matlab中调用它。(Matlab编译器编译MEX文件,在Windows上本质上是一个DLL。)您可以使用我在答案中建议的方法,它非常简单易行,并且随着n的增加按比例扩展为n*(n-1)/2+3*n - kol
你可以不用MATLAB编译器就能做到这一点。 - Sam Roberts
@SamRoberts 是的,你可以在没有Matlab的情况下编译C DLLs,但是我不想费力地处理Matlab矩阵的内部表示。 - kol
@SamRoberts 哦,是的。你说得对。那是很久以前的事了... :) - kol
有趣的是,我写了mex文件,但实际上它比只做x'Ax要慢一点,而这仍然是我找到的最快的方法。C代码(不包括所有额外的检查)在这里:https://gist.github.com/1428255 我猜调用mex文件的开销抵消了潜在的性能提升? - dantswain
显示剩余8条评论

1

MATLAB非常聪明,能够识别和优化某些类型的复合矩阵表达式,我相信(虽然我不能确定)二次形式是它所做的优化之一。

然而,这不是MathWorks通常会记录的内容,因为: a)它通常只在函数内部进行优化,而不是在脚本、命令行或调试中; b)它可能只适用于某些情况,例如对于实数非稀疏A; c)它可能会随着版本的发布而改变,因此他们不希望您依赖它; d)这是使MATLAB如此出色的专有技术之一。

要确认这一点,您可以尝试比较y=x'*A*xB=A*x; y=x'*B的时间。您还可以尝试使用feature('accel','off'),这将关闭大多数这类优化。

最后,如果您联系MathWorks支持,您可能会得到其中一位开发人员的确认是否正在进行优化。


MATLAB足够聪明,能够识别和优化这种类型的表达式。你能以某种方式证明这个说法吗?我们都不知道当Matlab编译和运行这样的表达式时背后发生了什么。 - kol
看看Octave对此的处理也很有趣,我的意思是,是否有任何优化。Octave是开源的,所以如果有人有时间(很多时间...),找出来并不是不可能的。 - kol

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接