AVX中的矩阵向量乘法并不比SSE快。

8

我正在使用以下内容编写SSE和AVX的矩阵向量乘法:

for(size_t i=0;i<M;i++) {
    size_t index = i*N;
    __m128 a, x, r1;
    __m128 sum = _mm_setzero_ps();
    for(size_t j=0;j<N;j+=4,index+=4) {
         a = _mm_load_ps(&A[index]);
         x = _mm_load_ps(&X[j]);
         r1 = _mm_mul_ps(a,x);
         sum = _mm_add_ps(r1,sum);
    }
    sum = _mm_hadd_ps(sum,sum);
    sum = _mm_hadd_ps(sum,sum);
    _mm_store_ss(&C[i],sum);
}

我在AVX中使用了类似的方法,但是由于AVX没有与_mm_store_ss()等效的指令,在最后我使用了:

_mm_store_ss(&C[i],_mm256_castps256_ps128(sum));

SSE代码使我的程序在串行代码的基础上加速了3.7倍。然而,AVX代码只使我的程序在串行代码的基础上加速了4.3倍。

我知道使用SSE和AVX可能会引起问题,但我使用g++编译时使用了“-mavx”标志,应该可以去除SSE操作码。

我也可以使用:_mm256_storeu_ps(&C[i],sum)来做同样的事情,但加速效果是相同的。

有什么其他的见解可以改善性能吗?这可能与performance_memory_bound有关,尽管我没有清楚地理解那个线程上的答案。

此外,即使包括“immintrin.h”头文件,我也无法使用_mm_fmadd_ps()指令。我已经启用了FMA和AVX。


2
可能是CPU在等待内存IO时只是空闲着。这意味着它实际上计算得更快,但是同样会因为等待更长时间的下一块数据而被卡住。 - Marc Claesen
_mm_store_ss(&C[i],_mm256_castps256_ps128(sum)); 是AVX中的等效指令。SSE指令只操作256位AVX寄存器的低128位。强制转换只是为了让编译器满意,不使用任何指令。 - Z boson
2
你应该至少尝试展开你的循环。 - Z boson
@ChristianRau:是的,我还需要一个额外的_mm256_hadd_ps()函数。 - user1715122
1
我不仅在进行矩阵向量的计算,还有矩阵矩阵。我采用了很多优化措施,包括循环展开、循环分块、AVX和OpenMP等。实际上要达到峰值浮点运算50%以上是相当困难的。最终,我的峰值浮点运算速度达到了70%,仍然比MKL低但比Eigen快。 - Z boson
显示剩余6条评论
3个回答

5
我建议您重新考虑您的算法。请参阅讨论 Efficient 4x4 matrix vector multiplication with SSE: horizontal add and dot product - what's the point?。 您一次做一个长点积,并在每次迭代中使用_mm_hadd_ps。相反,您应该使用SSE(AVX为八个)同时进行四个点乘积,并仅使用垂直运算符。
您需要加法,乘法和广播。这都可以用SSE来完成_mm_add_ps_mm_mul_ps_mm_shuffle_ps(用于广播)。
如果您已经有矩阵的转置,那么这非常简单。
但无论您是否具有转置,都需要使代码更具缓存友好性。为解决此问题,我建议将矩阵分块。请参见本讨论 What is the fastest way to transpose a matrix in C++? 以了解如何进行矩阵分块。
我建议您在尝试SSE/AVX之前先正确地使用循环分块。我在矩阵乘法中取得的最大提升并不是来自SIMD或线程,而是来自循环分块。我认为如果您正确使用缓存,您的AVX代码将表现得更加线性比SSE。

3
考虑下面的代码。我不熟悉INTEL版本,但这比DirectX中的XMMatrixMultiply更快。它不是关于每个指令所执行的数学运算量有多少,而是要减少指令计数(只要您使用快速指令,此实现就可做到)。
// Perform a 4x4 matrix multiply by a 4x4 matrix 
// Be sure to run in 64 bit mode and set right flags
// Properties, C/C++, Enable Enhanced Instruction, /arch:AVX 
// Having MATRIX on a 32 byte bundry does help performance
struct MATRIX {
    union {
        float  f[4][4];
        __m128 m[4];
        __m256 n[2];
    };
}; MATRIX myMultiply(MATRIX M1, MATRIX M2) {
    MATRIX mResult;
    __m256 a0, a1, b0, b1;
    __m256 c0, c1, c2, c3, c4, c5, c6, c7;
    __m256 t0, t1, u0, u1;

    t0 = M1.n[0];                                                   // t0 = a00, a01, a02, a03, a10, a11, a12, a13
    t1 = M1.n[1];                                                   // t1 = a20, a21, a22, a23, a30, a31, a32, a33
    u0 = M2.n[0];                                                   // u0 = b00, b01, b02, b03, b10, b11, b12, b13
    u1 = M2.n[1];                                                   // u1 = b20, b21, b22, b23, b30, b31, b32, b33

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0));        // a0 = a00, a00, a00, a00, a10, a10, a10, a10
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0));        // a1 = a20, a20, a20, a20, a30, a30, a30, a30
    b0 = _mm256_permute2f128_ps(u0, u0, 0x00);                      // b0 = b00, b01, b02, b03, b00, b01, b02, b03  
    c0 = _mm256_mul_ps(a0, b0);                                     // c0 = a00*b00  a00*b01  a00*b02  a00*b03  a10*b00  a10*b01  a10*b02  a10*b03
    c1 = _mm256_mul_ps(a1, b0);                                     // c1 = a20*b00  a20*b01  a20*b02  a20*b03  a30*b00  a30*b01  a30*b02  a30*b03

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1));        // a0 = a01, a01, a01, a01, a11, a11, a11, a11
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1));        // a1 = a21, a21, a21, a21, a31, a31, a31, a31
    b0 = _mm256_permute2f128_ps(u0, u0, 0x11);                      // b0 = b10, b11, b12, b13, b10, b11, b12, b13
    c2 = _mm256_mul_ps(a0, b0);                                     // c2 = a01*b10  a01*b11  a01*b12  a01*b13  a11*b10  a11*b11  a11*b12  a11*b13
    c3 = _mm256_mul_ps(a1, b0);                                     // c3 = a21*b10  a21*b11  a21*b12  a21*b13  a31*b10  a31*b11  a31*b12  a31*b13

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2));        // a0 = a02, a02, a02, a02, a12, a12, a12, a12
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2));        // a1 = a22, a22, a22, a22, a32, a32, a32, a32
    b1 = _mm256_permute2f128_ps(u1, u1, 0x00);                      // b0 = b20, b21, b22, b23, b20, b21, b22, b23
    c4 = _mm256_mul_ps(a0, b1);                                     // c4 = a02*b20  a02*b21  a02*b22  a02*b23  a12*b20  a12*b21  a12*b22  a12*b23
    c5 = _mm256_mul_ps(a1, b1);                                     // c5 = a22*b20  a22*b21  a22*b22  a22*b23  a32*b20  a32*b21  a32*b22  a32*b23

    a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3));        // a0 = a03, a03, a03, a03, a13, a13, a13, a13
    a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3));        // a1 = a23, a23, a23, a23, a33, a33, a33, a33
    b1 = _mm256_permute2f128_ps(u1, u1, 0x11);                      // b0 = b30, b31, b32, b33, b30, b31, b32, b33
    c6 = _mm256_mul_ps(a0, b1);                                     // c6 = a03*b30  a03*b31  a03*b32  a03*b33  a13*b30  a13*b31  a13*b32  a13*b33
    c7 = _mm256_mul_ps(a1, b1);                                     // c7 = a23*b30  a23*b31  a23*b32  a23*b33  a33*b30  a33*b31  a33*b32  a33*b33

    c0 = _mm256_add_ps(c0, c2);                                     // c0 = c0 + c2 (two terms, first two rows)
    c4 = _mm256_add_ps(c4, c6);                                     // c4 = c4 + c6 (the other two terms, first two rows)
    c1 = _mm256_add_ps(c1, c3);                                     // c1 = c1 + c3 (two terms, second two rows)
    c5 = _mm256_add_ps(c5, c7);                                     // c5 = c5 + c7 (the other two terms, second two rose)

    // Finally complete addition of all four terms and return the results
    mResult.n[0] = _mm256_add_ps(c0, c4);       // n0 = a00*b00+a01*b10+a02*b20+a03*b30  a00*b01+a01*b11+a02*b21+a03*b31  a00*b02+a01*b12+a02*b22+a03*b32  a00*b03+a01*b13+a02*b23+a03*b33
                                                //      a10*b00+a11*b10+a12*b20+a13*b30  a10*b01+a11*b11+a12*b21+a13*b31  a10*b02+a11*b12+a12*b22+a13*b32  a10*b03+a11*b13+a12*b23+a13*b33
    mResult.n[1] = _mm256_add_ps(c1, c5);       // n1 = a20*b00+a21*b10+a22*b20+a23*b30  a20*b01+a21*b11+a22*b21+a23*b31  a20*b02+a21*b12+a22*b22+a23*b32  a20*b03+a21*b13+a22*b23+a23*b33
                                                //      a30*b00+a31*b10+a32*b20+a33*b30  a30*b01+a31*b11+a32*b21+a33*b31  a30*b02+a31*b12+a32*b22+a33*b32  a30*b03+a31*b13+a32*b23+a33*b33
    return mResult;
}

1

正如有人提议的那样,添加-funroll-loops。

奇怪的是这不是默认设置。

对于任何浮点指针的定义,请使用__restrict。 对于常量数组引用,请使用const。 我不知道编译器是否足够聪明,能够识别循环内部的三个中间值不需要从迭代到迭代保持活动状态。 我会删除这3个变量或者至少将它们设为循环内部的局部变量(a、x、r1)。索引可以被声明,在其中j也被声明以使其更局部化。 确保M和N被声明为const,并且如果它们的值是编译时常量,则让编译器看到它们。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接