在Matlab中高效实现张量点积

3
在MATLAB中,我有一系列2x2矩阵堆叠成3D张量,我想对每个矩阵实例执行矩阵乘法。因此,我的C = A * B定义为:
C_ijk = sum(a_ilk * b_ljk, over all l)

我的当前实现看起来像这样

    function mats = mul3D(A, B)
        % given a list of 2D matrices (e.g. rotation matrices) applies
        % the matrix product for each instance along the third dimension
        % mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i
        % for this to succeed matrix dimensions must agree.
        mats = zeros(size(A,1), size(B,2), size(B,3));
        for i=1:size(B, 3)
            mats(:,:,i) = A(:,:,i) * B(:,:,i);
        end
    end

这段代码很容易理解,但我记得有人说MATLAB不喜欢for循环。

你能否想到更好的实现方式,在不增加内存消耗的情况下提高速度?我的代码在运行时有大约50%的时间花费在这个for循环中。

编辑

感谢您的建议。不幸的是,我无法向第三方代码引入新的依赖项。

根据您的问题,我想到了利用张量的2 x 2 x n结构。我的最新实现如下:

    function mats = mul3D(A, B)
        % given a list of 2D matrices (e.g. rotation matrices) applies
        % the matrix product for each instance along the third dimension
        % mats(:,:,i) = A(:,:,i) * B(:,:,i) for all i
        % for this to succeed matrix dimensions must agree.
        mats = zeros(size(A,1), size(B,2), size(B,3));

        mats(1,1,:) = A(1,1,:) .* B(1,1,:) + A(1,2,:) .* B(2,1,:);
        mats(2,1,:) = A(2,1,:) .* B(1,1,:) + A(2,2,:) .* B(2,1,:);
        if(size(mats,2) > 1)
            mats(1,2,:) = A(1,1,:) .* B(1,2,:) + A(1,2,:) .* B(2,2,:);
            mats(2,2,:) = A(2,1,:) .* B(1,2,:) + A(2,2,:) .* B(2,2,:);
        end
    end

欢迎提出更多建议!


您IP地址为143.198.54.68,由于运营成本限制,当前对于免费用户的使用频率限制为每个IP每72小时10次对话,如需解除限制,请点击左下角设置图标按钮(手机用户先点击左上角菜单按钮)。 - Luis Mendo
“A”和“B”的典型形状是什么? - Divakar
我在这里找到了优化的库:链接。但是我在构建时遇到了问题。 - Rotem
以下代码 mats(:,:,i) = A(:,:,i) * B(:,:,i) 被定义为“点积”吗? - Rotem
1个回答

1
我建议您使用mtimesx
请参考这里: https://www.mathworks.com/matlabcentral/answers/62382-matrix-multiply-slices-of-3d-matricies mtimesx使用优化的mex文件来执行"Matrix multiply slices of 3d Matricies"mtimesx使用BLAST库 (BLAST库是Matlab安装的一部分)。
从这里下载mtimesx源代码:http://www.mathworks.com/matlabcentral/fileexchange/25977-mtimesx-fast-matrix-multiply-with-multi-dimensional-support 我在Matlab r2014b中构建mex文件时遇到了问题。
问题是Matlab r2014a以上版本缺少文件mexopts.bat
mex构建脚本使用mexopts.bat
我通过下载mexopts.bat解决了这个问题。
我正在使用Visual Studio 2010编译器,并在此处找到了匹配的mexopts.bathttp://www.dynare.org/DynareWiki/ConfigureMatlabWindowsForMexCompilation
我将mexopts.bat复制到本地文件夹:c:\Users\Rotem\AppData\Roaming\MathWorks\MATLAB\R2014b\
之后,mtimesx表现得非常好...
使用mex文件应该比使用for循环快得多。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接