MATLAB/Octave - 广义矩阵乘法

10

我希望能够编写一个通用的矩阵乘法函数。基本上,它应该能够执行标准的矩阵乘法,但它应该允许使用任何其他函数来代替两个二元运算符中的积/和。

目标是尽可能高效,既要考虑CPU也要考虑内存。当然,它总会比A*B效率低,但这里的重点在于运算符的灵活性。

读完各种有趣的帖子后,我想到了以下几个命令:

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd

方法1-3的问题在于它们会生成n个矩阵,再使用sum()将它们折叠。方法4更好,因为它在bsxfun内部执行sum(),但bsxfun仍会生成n个矩阵(除了它们大多数是空的,只包含一个非零值向量的和,其余用0填充以匹配尺寸要求)。

我想要的是像第四种方法一样,但不需要用无用的0来节省内存。

有什么想法吗?


1
为什么不尝试使用稀疏矩阵来节省内存分配呢?你可能能够让它工作起来。spfun类似于bsxfun,但是针对稀疏矩阵,因此我假设它在后台也保持了较低的内存使用率。 - MZimmerman6
已经完成,实际上第四种方法应该能够利用稀疏性赚取利润,但不幸的是它不能与Octave一起使用,因为它的bsxfun运算符不友好于稀疏矩阵,所以所有内容都将被存储在内存中。 - gaborous
1
我另外一个问题是,你处理的矩阵有多大,以至于你如此关注内存? - MZimmerman6
@RodyOldenhuis:感谢您的反馈,确实第三和第四个方法只适用于Octave。这是寻找替代方案的另一个原因,因为第四种方法的问题正是我想要解决的问题:输出维度不正确。 - gaborous
@MZimmerman6:对于每个矩阵,其大小在n10E4和n10E5之间(其中n通常在20到6400之间),并且这应该能够处理更大的矩阵。但是它们非常稀疏,因此我可以重度使用内存技巧来提高计算能力。 - gaborous
显示剩余3条评论
4个回答

4
这是您发布的解决方案的稍微优化版本,有一些小改进。
我们检查行数是否大于列数或者反过来,然后根据情况选择将行与矩阵相乘还是矩阵与列相乘(从而进行最少的循环迭代)。

A*B

注意:即使行数少于列数,按行切片并不总是最佳策略;MATLAB数组在内存中以列优先顺序存储,因此按列切片更有效,因为元素是连续存储的。而访问行需要通过步长遍历元素(这不是缓存友好的——考虑空间局部性)。
除此之外,代码应该处理双精度/单精度、实数/复数、完整/稀疏矩阵(以及不能组合的错误)。它还支持空矩阵和零维度。
function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end

比较

该函数无疑在整个过程中速度较慢,但对于更大的尺寸,它比内置矩阵乘法慢了几个数量级:

        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51

mtimes_vs_mymtimes

这是测试代码:
sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight

额外内容

为了完整性,以下是另外两种实现广义矩阵乘法的简单方法(如果您想比较性能,请将my_mtimes函数的最后一部分替换为其中任意一种)。我甚至不会费心去发布它们的经过时间 :)

C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end

另一种方法(使用三重循环):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end

接下来该尝试什么?

如果您想要更多的性能,您需要转向使用C/C++ MEX文件,以减少解释型MATLAB代码的开销。您仍然可以通过从MEX文件中调用它们来利用优化的BLAS/LAPACK例程(请参阅此帖子的第二部分作为示例)。MATLAB附带Intel MKL库,当涉及到在Intel处理器上进行线性代数计算时,您无法找到更好的替代品。

其他人已经提到了File Exchange上的一些提交,这些提交将通用矩阵例程实现为MEX文件(请参见@natan的回答)。如果您将它们与优化的BLAS库链接,它们尤其有效。


M的开关技巧不错,但是应该是M < N而不是M > N(因为我们想要减少循环次数),并且C(m,:) = outFcn(bsxfun(inFcn, A(m,:), B), 1); 应该改为C(m,:) = outFcn(bsxfun(inFcn, A(:,m), B), 1);以充分利用列主序(对于稀疏矩阵来说更加重要)。另外一个小问题是,你的代码使用了一些在Octave 3.8.1中无法工作的函数(args{:},superiorfloat)。 不过,确实使用了你的代码后速度有所提升,所以你目前的答案最接近正确答案。 - gaborous
@user1121352:1)是的,我的错,应该是M<N。我会修正它。2)不,现在是正确的(我认为你在A(m,:)中错过了转置)。这个想法是将A的行“乘以”矩阵B并进行水平连接。3)我没有用Octave测试过它,但很容易适应代码。语法zeros(.., 'like',X)还没有进入Octave,您可以用类似的zeros(.., class(X))替换它(虽然不会选择稀疏属性)。 - Amro
至于superiorfloat调用,您可以将其替换为手动检查以选择适当的类型,即在single/double之间进行选择(这是由于在MATLAB中数据类型传播的方式在组合不同的数字类时) - Amro
在之前的评论中,我应该说“垂直连接”而不是“水平连接”。我还翻转了代码注释中的描述。现在已经修复了 :/ 我还添加了一条关于按行遍历与按列遍历的说明... - Amro
谢谢您提供的详细信息,但您确定使用A'(:,m)而不是A(m,:)'更好吗?这样A就可以按列切片而不是跨步。同样,C可以转置,然后可以使用C'(:,m)而不是C(m,:)来填充?无论如何,我会授予您奖励,但最后我会尝试使用LAPACK。如果那行不通,我会接受您的答案。 - gaborous
@user1121352:就像我解释的那样,这并不总是很明显。在循环之前和之后转置AC的成本可能会超过按列访问的好处。我想这取决于输入矩阵的大小...例如,我们可以将if条件更改为类似于if 2*M<N的内容,以证明按行切片是合适的(或者您认为其他任何乘数都可以)。最好的方法是运行测试,看看哪个平均运行时间更好 :) - Amro

3
为什么不利用bsxfun接受任意函数的能力?
C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);

这里

  • f外部函数(对应于矩阵乘法中的sum)。它应该接受任意大小的3D数组mxnxp,并沿着其列操作以返回一个1xmxp数组。
  • g内部函数(对应于矩阵乘法中的product)。与bsxfun一样,它应该接受两个大小相同的列向量或一个列向量和一个标量作为输入,并将一个与输入大小相同的列向量作为输出。

这在Matlab中可以工作。我还没有在Octave中测试过。


示例1:矩阵乘法:

>> f = @sum;   %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
    28    30
    64    69

检查:

>> A*B
ans =
    28    30
    64    69
例子2:考虑上面这两个矩阵,它们具有以下特征:
>> f = @(x,y) sum(abs(x));     %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
   14.8333   16.1538
    5.2500    5.6346

检查:手动计算 C(1,2)

>> sum(abs( max( (A(1,:))./(B(:,2)).', (B(:,2)).'./(A(1,:)) ) ))
ans =
   16.1538

谢谢您详细的回答,但它与Divakar上面提出的问题非常相似,问题在于内存爆炸,因为bsxfun无法对每个单例扩展执行求和乘积(而不是生成整个扩展,然后在第三维中求和)。例如,在Octave上尝试使用此矩阵:A = randi(2, 1000, 500)-1; 这将产生索引溢出。此外,这对于稀疏矩阵不起作用,因为如果稀疏,则无法排列到第三维。 - gaborous

1

不详细讨论,有一些工具如mtimesxMMX是快速的通用矩阵和标量操作程序。您可以查看它们的代码并根据需求进行调整。这很可能比Matlab的bsxfun更快。


3
在这里,使用C/C++代码变得更本地化绝对是正确的方法。 - Amro
我同意这将是速度和内存方面最好的解决方案,但我不确定您是否可以将函数作为GMM的参数传递,尽管我猜想可以在mex文件中定义最常见的运算符,并传递一个字符串作为参数进行选择。此外,还有一个LAPACK直接调用FEX,可能适合这里的需求,甚至不需要重写任何MEX代码:http://www.mathworks.com/matlabcentral/fileexchange/16777-lapack/content/lapack.m - gaborous
我尝试了LAPACK FEX库,但它不能与Octave一起使用,我不知道最新的MatLab版本是否可以。我发现了另一个有趣的项目:Mc2For项目,旨在从MatLab函数源代码生成Fortran 95代码:http://www.sable.mcgill.ca/mclab/matlab_fortran.html - gaborous

0
经过对bsxfun等几个处理函数的考察,似乎不可能使用它们进行直接的矩阵乘法运算(我的意思是临时乘积不会存储在内存中,而是立即求和,然后再处理其他的乘积和求和),因为它们的输出尺寸是固定的(要么与输入尺寸相同,要么通过bsxfun进行单元素扩展得到两个输入维度的笛卡尔积)。不过,可以通过一些小技巧欺骗Octave(这种方法在MatLab中无效,因为MatLab会检查输出尺寸):
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', sparse(1, size(A,1)))
C = bsxfun(@(a,b) sum(bsxfun(@times, a, B))', A', zeros(1, size(A,1), 2))(:,:,2)

然而不要使用它们,因为输出的值不可靠(Octave 可能会损坏甚至删除它们并返回 0!)。

所以现在我只实现了一个半向量化版本,这是我的函数:

function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.

if ~exist('inop', 'var')
    inop = @times;
end

if ~exist('outop', 'var')
    outop = @sum;
end

[n, m] = size(A);
[m2, o] = size(B);

if m2 ~= m
    error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end


C = [];
if issparse(A) || issparse(B)
    C = sparse(o,n);
else
    C = zeros(o,n);
end

A = A';
for i=1:n
    C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';

end

使用稀疏矩阵和普通矩阵进行测试,性能差距在稀疏矩阵中要小得多(慢3倍),而在普通矩阵中则要慢约100倍。

我认为这比bsxfun实现要慢,但至少不会溢出内存:

A = randi(10, 1000);
C = genmtimes(A, A');

如果有更好的选择,我仍在寻找更好的替代方案!


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接