矩阵乘法:Strassen算法 vs. 标准算法

4

我试图使用C ++实现 Strassen算法进行矩阵乘法,但结果并非我所期望的。如您所见,Strassen始终需要比标准实现更多的时间,并且仅在维数为2的幂时才与标准实现一样快。出了什么问题? alt text

matrix mult_strassen(matrix a, matrix b) {
if (a.dim() <= cut)
    return mult_std(a, b);

matrix a11 = get_part(0, 0, a);
matrix a12 = get_part(0, 1, a);
matrix a21 = get_part(1, 0, a);
matrix a22 = get_part(1, 1, a);

matrix b11 = get_part(0, 0, b);
matrix b12 = get_part(0, 1, b);
matrix b21 = get_part(1, 0, b);
matrix b22 = get_part(1, 1, b);

matrix m1 = mult_strassen(a11 + a22, b11 + b22); 
matrix m2 = mult_strassen(a21 + a22, b11);
matrix m3 = mult_strassen(a11, b12 - b22);
matrix m4 = mult_strassen(a22, b21 - b11);
matrix m5 = mult_strassen(a11 + a12, b22);
matrix m6 = mult_strassen(a21 - a11, b11 + b12);
matrix m7 = mult_strassen(a12 - a22, b21 + b22);

matrix c(a.dim(), false, true);
set_part(0, 0, &c, m1 + m4 - m5 + m7);
set_part(0, 1, &c, m3 + m5);
set_part(1, 0, &c, m2 + m4);
set_part(1, 1, &c, m1 - m2 + m3 + m6);

return c; 
}


程序
matrix.h http://pastebin.com/TYFYCTY7
matrix.cpp http://pastebin.com/wYADLJ8Y
main.cpp http://pastebin.com/48BSqGJr

使用 g++ 编译 main.cpp 和 matrix.cpp,并将输出保存为 matrix,使用 -O3 优化选项。


如果你想让别人帮助你,应该提供当前结果和期望结果(也许还有乘法函数)。放整个代码太多了。另外,如果问题是关于作业的,请添加作业标志。 - Patrice Bernassola
这不是作业,我只是对实现 Strassen 算法感兴趣,因为它应该更快。标准算法的时间复杂度为 O(n^3),而 Strassen 的时间复杂度为 O(n^2.8),因为它需要少一次乘法。 - multiholle
5个回答

8

一些想法:

  • 你是否优化过它来考虑非二次幂大小的矩阵填充为零?我认为该算法假定您不会乘以这些项,这就是为什么在2^n和2^(n+1)-1之间运行时间恒定的平坦区域。通过不乘以你知道是零的项,你应该能够改进这些区域。或者Strassen只适用于2^n大小的矩阵。
  • 请考虑“大”矩阵是任意的,而且该算法只比朴素情况稍微好一点,O(N^3)对O(N^2.8)。你可能需要尝试更大的矩阵才能看到可测量的增益。例如,我进行了一些有限元建模,其中10,000x10,000矩阵被认为是“小”的。从你的图表中很难看出来,但在Stassen情况下,511可能会更快。
  • 尝试使用各种优化级别进行测试,包括没有任何优化。
  • 这个算法似乎假设乘法比加法昂贵得多。40年前它首次开发时当然是这样的,但我认为在更现代的处理器上,加和乘的差距已经变小了。这可能会降低算法的效果,似乎减少乘法但增加加法。
  • 您是否查看其他Strassen实现以获取想法?尝试对已知的良好实现进行基准测试,以确定您可以获得多少更快的速度。

使用 Strassen 算法时,通过修改矩阵的存储顺序(使用 Z-Order)也可以帮助加快速度,使内存访问更加缓存友好。 - void-pointer

2

好的,我虽然不是这个领域的专家,但在这里可能有其他问题需要处理,而不仅仅是处理速度。首先,Strassen方法使用了更多的堆栈和更多的函数调用,这会增加内存移动。您的堆栈越大,惩罚就越大,因为它需要从操作系统请求更大的帧。此外,您使用了动态分配,这也是一个问题。

尝试使用固定大小(使用模板参数)的矩阵类?这将至少解决分配问题。

注意:我不确定它是否与您的代码正常工作。您的矩阵类使用指针,但没有复制构造函数或赋值运算符。最后,您还泄漏了内存,因为您没有析构函数...


2
Strassen算法的大O表示法为O(N ^ log 7),而普通算法为O(N ^ 3),即以2为底数的log 7,略小于3。这是所需进行的乘法次数。它假设您没有任何其他成本,并且只有在N足够大时才会更快,而您可能并非如此。您的许多实现都是创建许多子矩阵,我的猜测是您存储它们的方式可能需要每次分配内存和复制。如果可以使用某种“切片”矩阵和逻辑转置矩阵,则可以帮助您优化过程中可能最慢的部分。

1

我真的很惊讶我的Stassen乘法实现速度有多快:

https://github.com/wcochran/strassen_multiplier/blob/master/mm.c

当n=1024时,我的计算机速度提高了近16倍。我想唯一能解释这么大的速度提升的方法就是我的算法更加缓存友好——也就是说,它关注于矩阵的小块,因此数据更加局部化。
您C++实现中的开销可能太高了——编译器生成了比实际必要的临时变量更多。我的实现尽可能地重用内存来尝试最小化这个问题。

对于1024 x 1024的矩阵,您需要使用分块算法 - 因为每次乘法都会导致一个平均缓存未命中,所以朴素实现将非常慢。您真的需要调查一下如果您取n = 1023或1025,时间如何变化,并相应地更改矩阵的布局。 - gnasher729
代码链接已损坏。 - CubicleSoft

0
有点冒险,但你考虑过编译器可能通过优化来优化标准乘法吗?你能关闭优化吗?

关闭优化会增加时间,但是会得到相同的图形。 - multiholle
嗯,我觉得我无法再帮助你了。在你的 Strassen 实现中,似乎有很多函数调用正在访问主内存。也许这是瓶颈所在?朴素的乘法会相当地利用寄存器。我建议使用 BLAS 库进行矩阵乘法。 - Eamorr
2
关闭优化有点违背了任何基准测试的意义,因为它应该提供在实际生活中有用的结果。 - Kos
1
哇,这个负反馈是怎么回事?我没有看到你提出任何建议。 - Eamorr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接