将一个三维矩阵与一个二维矩阵相乘。

27
假设我有一个 AxBxC 矩阵 X 和一个 BxD 矩阵 Y。是否有一种非循环方法,可以将 Y 与每个 C AxB 矩阵相乘?

5
为什么要费这个劲呢?我看了Gnovice(正确的)解决方案,需要花费相当长的时间才能理解它的作用。然后我看了Zaid的方案,立即就明白了。如果存在性能差异,也要考虑维护成本。 - MatlabDoug
2
这不是关于性能或可读性的问题 - 只是出于好奇心,因为我知道可以单独操作每个3D矩阵,但无法弄清楚如何操作。我知道Gnovice的解决方案比Zaid的“解决方案”和Amro的解决方案慢得多,但正如我所说,这不是重点。 - Jacob
1
现在你完全把我搞糊涂了... 你到底想要什么? - Zaid
我需要一种非循环的方法,可以将C个AxB矩阵中的每一个都与Y相乘,例如Amro和GNovice的解决方案。 - Jacob
6
@Jacob:1. gnovice的解决方案并不比amro的慢。2. gnovice的解决方案使用了cellfun,它是一个包装循环的函数。所以你可以把Zaid的解决方案做成一个叫做prod3D.m的函数,调用它,就可以得到一个非循环的方法来乘以X和Y。3. 别忘了80%的软件成本是维护成本。 - Mikhail Poda
不要忘记,80%的软件成本是维护费用。- 你太棒了,@Mikhail! - user649198
10个回答

18

个人偏好,我喜欢我的代码尽可能简洁易读。

以下是我会做的做法,但不符合你的“无循环”要求:

for m = 1:C

    Z(:,:,m) = X(:,:,m)*Y;

end

这将导致一个 A x D x C 的矩阵 Z

当然,你可以通过使用 Z = zeros(A,D,C); 来预先分配Z以加快速度。


4
-1:因为无论您免责声明如何,这都不是一个真正的解决方案。如果您对简洁性或易读性有任何意见,请将它们留作评论。 - Jacob
6
+1 是因为它比 gnovice 和 amro 的好解决方案更快。 - Ramashalanka
但请先使用Z = zeros([A D C]);预分配Z! (为可读性加1) - Floris

15
你可以使用函数NUM2CELL将矩阵X分解成单元数组,然后再使用CELLFUN函数在单元格之间操作,一行代码即可完成。
Z = cellfun(@(x) x*Y,num2cell(X,[1 2]),'UniformOutput',false);

结果Z是一个1×C的单元数组,其中每个单元包含一个A×D的矩阵。如果你想让Z成为一个A×D×C的矩阵,你可以使用CAT函数:

Z = cat(3,Z{:});

注意:我的旧解决方法使用了 MAT2CELL 而不是 NUM2CELL,这种方法不够简洁。

[A,B,C] = size(X);
Z = cellfun(@(x) x*Y,mat2cell(X,A,B,ones(1,C)),'UniformOutput',false);

2
使用此解决方案,循环位于cellfun内部。但是,在大矩阵上它仍然比amro提供的解决方案快10%(在MATLAB即将耗尽内存之前)。 - Mikhail Poda
我很好奇我得到的两个踩。无论你是否喜欢这个答案,它通过避免显式使用for循环来回答问题。 - gnovice
2
哇,谁会想到一个简单的问题会引起这么多争议呢? - Jacob
1
@Jacob:是的,看起来它引发了一些争议。由于我之前见过你回答MATLAB问题,所以我认为你已经知道如何使用循环(最直接的方法)来完成这个问题。我只是假设你是出于好奇想知道其他方法也可以完成它。 - gnovice
@gnovice:你猜对了。我正在寻找一种类似于bsxfun的方法来完成这个任务,而你的cellfun实现正好符合要求。我原以为“非循环”条件足以清楚地表达我的意图,看来并不是! - Jacob

8

这里有一个一行解决方案(如果您想分成第三个维度,则为两行):

A = 2;
B = 3;
C = 4;
D = 5;

X = rand(A,B,C);
Y = rand(B,D);

%# calculate result in one big matrix
Z = reshape(reshape(permute(X, [2 1 3]), [A B*C]), [B A*C])' * Y;

%'# split into third dimension
Z = permute(reshape(Z',[D A C]),[2 1 3]);

因此现在:Z(:,:,i) 包含了 X(:,:,i) * Y 的结果。
解释: 上述内容可能看起来很混乱,但是想法很简单。 首先,我以 X 的第三维为起点,在第一个维度上进行垂直连接:
XX = cat(1, X(:,:,1), X(:,:,2), ..., X(:,:,C))

...难点在于C是一个变量,因此您无法使用catvertcat来概括该表达式。接下来,我们将其乘以Y

ZZ = XX * Y;

最后,我将它再次拆分成第三维:
Z(:,:,1) = ZZ(1:2, :);
Z(:,:,2) = ZZ(3:4, :);
Z(:,:,3) = ZZ(5:6, :);
Z(:,:,4) = ZZ(7:8, :);

所以您可以看到,它只需要进行一次矩阵乘法,但是您需要在之前和之后对矩阵进行重塑


谢谢!我希望有一个类似于bsxfun的解决方案,但这看起来很有趣。 - Jacob
不需要。正如我所添加的解释所示,只需要通过重新排列矩阵的形状来准备它,这样简单的乘法就足够了。 - Amro
不错的解决方案,但由于重塑操作可能会导致内存溢出。 - gaborous
1
正如OP在评论中提到的那样,这里的动机是探索替代解决方案(为了好玩),而不是生产更快或更易读的代码... 在生产代码中,我会坚持使用直接的for循环 :) - Amro

6
我遇到了同样的问题,并试图找出最有效的方法。我大致看到有三种方法可供选择,除了使用外部库(例如mtimesx):
  1. 循环遍历3D矩阵的切片
  2. 使用repmat和permute函数
  3. 使用cellfun进行矩阵乘法
我最近比较了这三种方法,看哪一种最快。我的直觉是第二种方法会是赢家。以下是代码:
% generate data
A = 20;
B = 30;
C = 40;
D = 50;

X = rand(A,B,C);
Y = rand(B,D);

% ------ Approach 1: Loop (via @Zaid)
tic
Z1 = zeros(A,D,C);
for m = 1:C
    Z1(:,:,m) = X(:,:,m)*Y;
end
toc

% ------ Approach 2: Reshape+Permute (via @Amro)
tic
Z2 = reshape(reshape(permute(X, [2 1 3]), [A B*C]), [B A*C])' * Y;
Z2 = permute(reshape(Z2',[D A C]),[2 1 3]);
toc


% ------ Approach 3: cellfun (via @gnovice)
tic
Z3 = cellfun(@(x) x*Y,num2cell(X,[1 2]),'UniformOutput',false);
Z3 = cat(3,Z3{:});
toc

这三种方法输出的结果都一样(太好了!),但令人惊讶的是,循环方式最快:

Elapsed time is 0.000418 seconds.
Elapsed time is 0.000887 seconds.
Elapsed time is 0.001841 seconds.

请注意,从一次试验到另一次试验,时间可能会有很大的差异,有时候(2)是最慢的。这些差异在处理更大的数据时变得更加明显。但是对于非常大的数据,(3)胜过(2)。循环方法仍然是最好的。
% pretty big data...
A = 200;
B = 300;
C = 400;
D = 500;
Elapsed time is 0.373831 seconds.
Elapsed time is 0.638041 seconds.
Elapsed time is 0.724581 seconds.

% even bigger....
A = 200;
B = 200;
C = 400;
D = 5000;
Elapsed time is 4.314076 seconds.
Elapsed time is 11.553289 seconds.
Elapsed time is 5.233725 seconds.

但是循环方法 可能会 比(2)慢,特别是当循环的维度比其他维度大得多时。

A = 2;
B = 3;
C = 400000;
D = 5;
Elapsed time is 0.780933 seconds.
Elapsed time is 0.073189 seconds.
Elapsed time is 2.590697 seconds.

因此,在这种(可能极端)情况下,(2)获胜的优势很大。也许没有一种方法是在所有情况下都最优的,但是循环仍然非常好,并且在许多情况下是最好的。从可读性角度来看,它也是最佳的。继续使用循环!


1
我强烈推荐您使用Matlab的MMX工具箱。它可以尽可能快地进行n维矩阵乘法运算。
MMX的优点如下:
  1. 易于使用。
  2. 可以乘以n维矩阵(实际上可以乘以2-D矩阵数组)
  3. 执行其他矩阵运算(转置、二次乘法、Chol分解等)
  4. 它使用C编译器和多线程计算来加速。
对于这个问题,您只需要编写以下命令:
C=mmx('mul',X,Y);

这里有一个所有可能方法的基准。更详细的内容请参考此问题

    1.6571 # FOR-loop
    4.3110 # ARRAYFUN
    3.3731 # NUM2CELL/FOR-loop/CELL2MAT
    2.9820 # NUM2CELL/CELLFUN/CELL2MAT
    0.0244 # Loop Unrolling
    0.0221 # MMX toolbox  <===================

1
为了回答这个问题并提高可读性,请参见:
  • ndmult,作者为ajuanpi(Juan Pablo Carbajal),2013年,GNU GPL

输入

  • 2个数组
  • dim

示例

 nT = 100;
 t = 2*pi*linspace (0,1,nT)’;

 # 2 experiments measuring 3 signals at nT timestamps
 signals = zeros(nT,3,2);
 signals(:,:,1) = [sin(2*t) cos(2*t) sin(4*t).^2];
 signals(:,:,2) = [sin(2*t+pi/4) cos(2*t+pi/4) sin(4*t+pi/6).^2];

 sT(:,:,1) = signals(:,:,1)’;
 sT(:,:,2) = signals(:,:,2)’;
   G = ndmult (signals,sT,[1 2]);

来源

原始来源。我添加了内联注释。

function M = ndmult (A,B,dim)
  dA = dim(1);
  dB = dim(2);

  # reshape A into 2d
  sA = size (A);
  nA = length (sA);
  perA = [1:(dA-1) (dA+1):(nA-1) nA dA](1:nA);
  Ap = permute (A, perA);
  Ap = reshape (Ap, prod (sA(perA(1:end-1))), sA(perA(end)));

  # reshape B into 2d
  sB = size (B);
  nB = length (sB);
  perB = [dB 1:(dB-1) (dB+1):(nB-1) nB](1:nB);
  Bp = permute (B, perB);
  Bp = reshape (Bp, sB(perB(1)), prod (sB(perB(2:end))));

  # multiply
  M = Ap * Bp;

  # reshape back to original format
  s = [sA(perA(1:end-1)) sB(perB(2:end))];
  M = squeeze (reshape (M, s));
endfunction

1

不行。有几种方法,但它总是以循环、直接或间接的方式出现。

只是为了满足我的好奇心,你为什么要那样做呢?


2
为什么我要不用循环来做呢?只是老习惯了。虽然MATLAB现在有JITA进行循环优化,但我尽可能避免使用它们——而且我强烈感觉可以不用循环来解决这个问题。 - Jacob
1
是的,好的,我能理解那个。 (相反,有时我会在循环中做一些可以不用循环完成的事情,因为我发现这样更容易阅读 <-- :(老习惯了,也是 :) - Rook

1
我想分享我对以下问题的解答:
1)如何制作两个张量(任意阶张量)的张量积;
2)沿任意维度将两个张量收缩。
以下是我编写的子程序,用于完成第一和第二个任务:
1)张量积:
function [C] = tensor(A,B)
   C = squeeze( reshape( repmat(A(:), 1, numel(B)).*B(:).' , [size(A),size(B)] ) );
end

2) 收缩运算: 这里的A和B是要沿着i和j维度进行收缩运算的张量。当然,这些维度的长度应该相等。代码中没有对此进行检查(这会使代码变得复杂),但除此之外,它运行良好。

   function [C] = tensorcontraction(A,B, i,j)
      sa = size(A);
      La = length(sa);
      ia = 1:La;
      ia(i) = [];
      ia = [ia i];

      sb = size(B);
      Lb = length(sb);
      ib = 1:Lb;
      ib(j) = [];
      ib = [j ib];

      % making the i-th dimension the last in A
      A1 = permute(A, ia);
      % making the j-th dimension the first in B
      B1 = permute(B, ib);

      % making both A and B 2D-matrices to make use of the
      % matrix multiplication along the second dimension of A
      % and the first dimension of B
      A2 = reshape(A1, [],sa(i));
      B2 = reshape(B1, sb(j),[]);

      % here's the implicit implication that sa(i) == sb(j),
      % otherwise - crash
      C2 = A2*B2;

      % back to the original shape with the exception
      % of dimensions along which we've just contracted
      sa(i) = [];
      sb(j) = [];
      C = squeeze( reshape( C2, [sa,sb] ) );
   end

任何批评意见?

非常感谢您分享这些代码。我测试了第一个,它的表现非常出色(我将其与使用嵌套for循环的代码进行了比较!)。您能否解释一下第一个代码背后的逻辑?为什么要在最后进行转置?再次感谢! - Shahram Khazaie
1
好的,最后的转置是为了对齐两个向量的维度。我们这里不应用矩阵乘法(即*操作),而是采用逐元素相乘(即.*操作),因此我们的维度应该完全相同;在这种情况下,我们将两行长度相同的向量相乘。B(:)表示一列,因此需要进行转置以使其也成为一行。 - Dan

0

我会考虑使用递归,但那是你能做的唯一的非循环方法。


0
你可以“展开”循环,即按顺序写出循环中会发生的所有乘法。

1
假设C是变量..等等。 - Jacob

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接