在MatLab中,将矩阵的列与3D矩阵的2D矩阵切片相乘

4
基本上,我想执行以下计算:
    G is m x n x k
    S is n x k

    Answer=zeros(m,d)
    for Index=1:k
        Answer(:,Index)=G(:,:,Index)*S(:,Index)
    end

因此,答案是一个矩阵,其列是将3D矩阵的每一层与另一个矩阵的列相乘的结果。

这似乎是一种简单直接的操作,我希望知道在Matlab中是否有本地或矢量化(或至少更快)的方式来执行这种计算。谢谢。

3个回答

2
尝试使用来自Matlab文件交换的mtimesx。这是我迄今为止发现的执行此类n维数组乘法的最佳(快速/高效)工具,因为它使用了mex。我认为你也可以使用bsxfun, 但我的Matlab技术还不足以胜任这种任务。
你有一个大小为m x n x km x k 的输入,想要产生一个n x k的输出。 mtimesx可以将i x j x kj x r x k这样的输入相乘,得到i x r x k的输出。
将您的问题转化为`mtimesx`形式,假设`G`是`m x n x k`的矩阵,并将`S`扩展为`n x 1 x k`的矩阵。然后,`mtimesx(G,S)`将是`m x 1 x k`的矩阵,然后可以压缩为`m x k`的矩阵。
m=3; 
n=4; 
k=2;
G=rand(m,n,k);
S=rand(n,k);

% reshape S
S2=reshape(S,n,1,k);

% do multiplication and flatten mx1xk to mxk
Ans_mtimesx = reshape(mtimesx(G,S2),m,k)

% try loop method to compare
Answer=zeros(m,k);
for Index=1:k
    Answer(:,Index)=G(:,:,Index)*S(:,Index);
end

% compare
norm(Ans_mtimesx-Answer)
% returns 0.

所以,如果你想要一个一行的代码,你可以这样做:
Ans = reshape(mtimesx(G,reshape(S,n,1,k)),m,k)

顺便说一下,如果你在Matlab Newsreader论坛上发布你的问题,那里会有很多高手竞相给你提供比我更优雅或更有效的答案!

谢谢,这看起来可以满足我的需求。我会在能够测试时再回来查看。 - deftfyodor
如果每个k切片中m变化怎么办? - siamii
请参见http://stackoverflow.com/questions/15248109/how-to-do-m-n-k-n-k-m-k-in-matlab。 - siamii

1

这是使用bsxfun()函数的版本。如果A是一个m×n矩阵,x是一个n×1向量,则可以计算出A*x:

sum(bsxfun(@times, A, x'), 2)

操作permute(S,[3 1 2])将获取S的列,并将它们分布在第三维度上作为行。 [3 1 2]是S维度的排列。
因此,sum(bsxfun(@times, G, permute(S,[3 1 2])), 2)可以得出答案,但结果仍在第三维度中。为了按照所需格式呈现结果,需要进行另一个permute操作。
permute(sum(bsxfun(@times, G, permute(S, [3 1 2])), 2), [1 3 2])

0

你可以将三维矩阵表示为二维分块对角矩阵,每一层都是一个对角块。在这种情况下,应该将二维矩阵表示为包含堆叠列的向量。如果矩阵很大,声明它为稀疏矩阵。


可以这样做,但我并不完全相信它会提高速度。稍后我会试一下。 - deftfyodor

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接