为什么我的Strassen矩阵乘法很慢?

14
我用C++编写了两个矩阵乘法程序: 普通的MM (源代码) 和 Strassen's MM (源代码),两者都适用于大小为2^k x 2^k(换句话说,是偶数大小的正方形矩阵)的矩阵。
结果非常糟糕。对于1024 x 1024矩阵,普通的MM需要46.381秒,而Strassen's MM需要1484.303秒(25分钟!)。
我尽可能保持代码简单。在网上找到的其他Strassen's MM示例与我的代码并没有太大区别。一个明显的问题是Strassen's代码缺少切换回普通MM的截止点。
我的Strassen's MM代码还有哪些问题?谢谢!
源代码的直接链接为: http://pastebin.com/HqHtFpq9 http://pastebin.com/USRQ5tuy 编辑1。 首先,非常感谢提供的许多好建议。我实现了更改(保留了所有代码),添加了截止点。 具有截止点的2048x2048矩阵的MM已经得到了良好的结果。 普通的MM:191.49秒 Strassen's MM:112.179秒 显著改善。 这些结果是在使用Visual Studio 2012的古老Lenovo X61 TabletPC上,使用Intel Centrino处理器获得的。 我将进行更多检查(以确保我获得了正确的结果),并将发布结果。

1
@LuchianGrigore:哦,这很微妙。很可能也是问题的重要部分。实际上可能比我发现的问题更严重。 - Omnifarious
1
@Mysticial 我认为其中一个算法是缓存无感知的,因为维度被固定为2^k。但我可能是错的。 - Luchian Grigore
1
@LuchianGrigore:一个算法非常简单(并且不是缓存无关的)。另一个算法虽然也不是缓存无关的,但应该要快得多。 - Omnifarious
1
@Omnifarious 我故意将我的代码设计为二的幂,否则代码的一部分需要被重写,因为矩阵在除法后不会是偶数,需要填充(零行)。 - newprint
3
好的,那么原因就是在这里解释了:https://dev59.com/K2gu5IYBdhLWcg3wUljt#11413856 - Luchian Grigore
显示剩余8条评论
2个回答

27

斯特拉森算法存在一个明显的问题 - 我没有截止点来切换到常规矩阵乘法。

可以说,递归到1点是问题的主要部分(如果不是整个问题)。试图猜测其他性能瓶颈而不解决这个问题几乎是无意义的,因为它会带来巨大的性能损失(换句话说,你在比较苹果和橙子)。

正如评论中讨论的那样,缓存对齐可能会有影响,但影响不至于如此之大。此外,缓存对齐可能会对常规算法产生更多负面影响,而不是斯特拉森算法,因为后者是缓存无关的。

void strassen(int **a, int **b, int **c, int tam) {

    // trivial case: when the matrix is 1 X 1:
    if (tam == 1) {
            c[0][0] = a[0][0] * b[0][0];
            return;
    }

太小了。虽然Strassen算法的复杂度更小,但它的大O常数要大得多。首先,你从头到尾都需要函数调用开销,直到剩下1个元素。

这类似于使用归并或快速排序,并将递归一直执行到一个元素。为了高效,当大小变小时,您需要停止递归并返回经典算法。

在快速/归并排序中,您会退回到低开销的O(n²)插入或选择排序。在这里,您将回到正常的O(n³)矩阵乘法。


回退到经典算法的阈值应该是可调整的,这个阈值可能会因硬件和编译器优化代码的能力而异。

对于像Strassen乘法这样只比经典的O(n³)快O(2.8074)的情况,如果此阈值非常高(数千个元素?),就不要感到惊讶。


在某些应用程序中,可能有许多算法,每个算法的复杂度都在降低,但大O却在增加。结果是,在不同的大小上,多种算法成为最优。

大整数乘法就是一个臭名昭著的例子:

*请注意,这些示例阈值是近似值,可以大幅变化,通常超过10倍。


3
这是我喜爱StackOverflow的原因之一。通过一个问题,我看到了现实世界中那些微妙的影响如何被放大并以明显的方式展示出来,这可以引起性能问题。当然,很可能这个答案导致了应该更快的算法变慢。 - Omnifarious
@Omnifarious 啊...在这种情况下,为了打破对齐而牺牲局部性实际上可以提高性能。 - Mysticial
@Mysticial:是的,我也在考虑这个问题。不过间接寻址也是一个问题。因此,所有竞争效应的确切结果并不清楚。我的实证测试表明,对于最不缓存友好的天真算法,连续分配数组可以提高50%的性能。 - Omnifarious
@tmyklebu 啊,这就是为什么我给了一个10的因子。这些数字大致上是y-cruncher中使用的值。GMP不使用浮点FFT,因为他们担心舍入误差。所以当你考虑到浮点FFT时,它将SSA(GMP使用的)的阈值推高到数十亿。在100,000位数左右之后,y-cruncher使用一整套未发布的算法。所以我实际上并不知道SSA和NTT的确切阈值,除了“远远超过数十亿”-因为我从未测试过。 - Mysticial
这是我曾经见过的最好的答案之一!对我来说,它就像黑暗中的一束光!非常感谢! - user3692586
显示剩余7条评论

3
所以,可能还有更多问题,但你的第一个问题是你正在使用指向数组的指针数组。由于您使用的是2的幂次方的数组大小,因此这对于将长整数数组折叠成行并连续分配元素并使用整数除法会产生特别大的性能影响。

总之,这是我猜测的第一个问题。正如我所说,可能还有更多问题,我会在发现它们时添加到这个答案中。

编辑:这可能仅对问题做出了很小的贡献。问题可能是Luchian Grigore所提到的与二次幂缓存行争用问题有关。

我验证了我的担忧对于天真算法是有效的。如果数组是连续的,则天真算法的时间减少了近50%。这里是pastebin上的代码(使用依赖于C ++ 11的SquareMatrix类)。


谢谢你的帮助!我明天会看一下你的代码。 - newprint
1
@newprint:我犯了几个小错误,对你的代码没有影响,但使得SquareMatrix类不适合一般使用。我会修复它们。 - Omnifarious
简化版- void madd(int N, int Xpitch, const double X[], int Ypitch, const double Y[], int Spitch, double S[]) { for (int i = 0; i < N; i++) for (int j = 0; j < N; j++) S[iSpitch + j] = X[iXpitch + j] + Y[i*Ypitch + j]; } - newprint
1
@newprint:我不喜欢那个版本,因为你必须记住在每次访问矩阵时都要使用乘法。但是它非常像C语言,并且完全不使用C++特性。:-) 我的版本(带有内联函数)使编译器能够进行一些有趣的假设,并进行一些非常好的优化,同时允许实际的乘法算法仍然以清晰地使用多维数组访问的方式呈现出来。 - Omnifarious

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接