如何优化SIMD转置函数(8x4 => 4x8)?

4
我需要使用AVX来优化8x4和4x8浮点矩阵的转置。我使用Agner Fog的向量类库vector class library
蓝绿色任务 - 构建BVH并求和最小值-最大值。在每个循环的最后阶段使用转置(他们还通过多线程进行了优化,但任务可能非常多)。
代码现在看起来像:
void transpose(register Vec4f (&fin)[8], register Vec8f (&mat)[4]) {
    for (int i = 0;i < 8;i++) {
        fin[i] = lookup<28>(Vec4i(0, 8, 16, 24) + i, (float *)mat);
    }
}

需要优化的变体。如何为SIMD优化此函数?


我最近使用VectorClass编写了自己的转置变体(4x8和8x4)。版本为1.0。

void transpose(register Vec4f(&fin)[8], register Vec8f(&mat)[4]) {
    register Vec8f a00 = blend8f<0, 8, 1, 9, 2, 10, 3, 11>(mat[0], mat[1]);
    register Vec8f a10 = blend8f<0, 8, 1, 9, 2, 10, 3, 11>(mat[2], mat[3]);
    register Vec8f a01 = blend8f<4, 12, 5, 13, 6, 14, 7, 15>(mat[0], mat[1]);
    register Vec8f a11 = blend8f<4, 12, 5, 13, 6, 14, 7, 15>(mat[2], mat[3]);

    register Vec8f v0_1 = blend8f<0, 1, 8, 9, 2, 3, 10, 11>(a00, a10);
    register Vec8f v2_3 = blend8f<4, 5, 12, 13, 6, 7, 14, 15>(a00, a10);
    register Vec8f v4_5 = blend8f<0, 1, 8, 9, 2, 3, 10, 11>(a01, a11);
    register Vec8f v6_7 = blend8f<4, 5, 12, 13, 6, 7, 14, 15>(a01, a11);

    fin[0] = v0_1.get_low();
    fin[1] = v0_1.get_high();
    fin[2] = v2_3.get_low();
    fin[3] = v2_3.get_high();
    fin[4] = v4_5.get_low();
    fin[5] = v4_5.get_high();
    fin[6] = v6_7.get_low();
    fin[7] = v6_7.get_high();
}

void transpose(register Vec8f(&fin)[4], register Vec4f(&mat)[8]) {
    register Vec8f a0_1 = Vec8f(mat[0], mat[1]);
    register Vec8f a2_3 = Vec8f(mat[2], mat[3]);
    register Vec8f a4_5 = Vec8f(mat[4], mat[5]);
    register Vec8f a6_7 = Vec8f(mat[6], mat[7]);

    register Vec8f a00 = blend8f<0, 4, 8 , 12, 1, 5, 9 , 13>(a0_1, a2_3);
    register Vec8f a10 = blend8f<0, 4, 8 , 12, 1, 5, 9 , 13>(a4_5, a6_7);
    register Vec8f a01 = blend8f<2, 6, 10, 14, 3, 7, 11, 15>(a0_1, a2_3);
    register Vec8f a11 = blend8f<2, 6, 10, 14, 3, 7, 11, 15>(a4_5, a6_7);

    fin[0] = blend8f<0, 1, 2, 3, 8, 9, 10, 11>(a00, a10);
    fin[1] = blend8f<4, 5, 6, 7, 12, 13, 14, 15>(a00, a10);
    fin[2] = blend8f<0, 1, 2, 3, 8, 9, 10, 11>(a01, a11);
    fin[3] = blend8f<4, 5, 6, 7, 12, 13, 14, 15>(a01, a11);
}

需要版本2.0。
2个回答

4
我没有使用过vectorclass库,但是从快速查看lookup模板函数的源代码中,似乎你正在做一些非常低效的事情。
我提出了一个简单而高效的解决方案,使用SSE/AVX内部函数。我不知道如何完全使用vectorclass库来编码它。然而,您可以使用转换运算符从类Vec4f和Vec8f中提取原始数据作为__m128和__m256,并使用适当的构造函数将原始结果转换回向量类。
在纯SSE内部函数中,头文件xmmintrin.h中有一个名为_MM_TRANSPOSE4_PS的宏。它将浮点数的4x4矩阵转置,每行在单独的128位寄存器中。如果您只有SSE(即没有AVX),那么您只需调用此宏两次即可完成。以下是代码:
#define _MM_TRANSPOSE4_PS(row0, row1, row2, row3) {    \
  __m128 tmp3, tmp2, tmp1, tmp0;                      \
  tmp0 = _mm_shuffle_ps(row0, row1, 0x44);            \
  tmp2 = _mm_shuffle_ps(row0, row1, 0xEE);            \
  tmp1 = _mm_shuffle_ps(row2, row3, 0x44);            \
  tmp3 = _mm_shuffle_ps(row2, row3, 0xEE);            \
  row0 = _mm_shuffle_ps(tmp0, tmp1, 0x88);            \
  row1 = _mm_shuffle_ps(tmp0, tmp1, 0xDD);            \
  row2 = _mm_shuffle_ps(tmp2, tmp3, 0x88);            \
  row3 = _mm_shuffle_ps(tmp2, tmp3, 0xDD);            \
}

在AVX中,具有256位操作数的指令通常只对操作数(称为lane)的两个半部分执行SSE等效操作。而intrinsic“_mm256_shuffle_ps”也不例外:它简单地将两个128位lane混洗成其_mm等效项所做的方式。这意味着,如果在宏中将_mm前缀更改为_mm256前缀,则会转置两个4x4矩阵:一个位于四个256位寄存器的低位lane中,另一个位于四个256位寄存器的高位lane中。我们只需将结果的256位寄存器分为两半并适当排序即可。
以下是生成的代码。我已经检查过它可以正常工作。似乎只有12条指令,所以我认为它会很快。
void Transpose4x8(__m128 dst[8], __m256 src[4]) {
  __m256 row0 = src[0], row1 = src[1], row2 = src[2], row3 = src[3];
  __m256 tmp3, tmp2, tmp1, tmp0;
  tmp0 = _mm256_shuffle_ps(row0, row1, 0x44);
  tmp2 = _mm256_shuffle_ps(row0, row1, 0xEE);
  tmp1 = _mm256_shuffle_ps(row2, row3, 0x44);
  tmp3 = _mm256_shuffle_ps(row2, row3, 0xEE);
  row0 = _mm256_shuffle_ps(tmp0, tmp1, 0x88);
  row1 = _mm256_shuffle_ps(tmp0, tmp1, 0xDD);
  row2 = _mm256_shuffle_ps(tmp2, tmp3, 0x88);
  row3 = _mm256_shuffle_ps(tmp2, tmp3, 0xDD);
  dst[0] = _mm256_castps256_ps128(row0);
  dst[1] = _mm256_castps256_ps128(row1);
  dst[2] = _mm256_castps256_ps128(row2);
  dst[3] = _mm256_castps256_ps128(row3);
  dst[4] = _mm256_extractf128_ps(row0, 1);
  dst[5] = _mm256_extractf128_ps(row1, 1);
  dst[6] = _mm256_extractf128_ps(row2, 1);
  dst[7] = _mm256_extractf128_ps(row3, 1);
}

更新 反转置的方式与正常置换基本相同,只是有些事情顺序相反。以下是代码:

void Transpose8x4(__m256 dst[4], __m128 src[8]) {
  __m256 row0 = _mm256_setr_m128(src[0], src[4]);
  __m256 row1 = _mm256_setr_m128(src[1], src[5]);
  __m256 row2 = _mm256_setr_m128(src[2], src[6]);
  __m256 row3 = _mm256_setr_m128(src[3], src[7]);
  __m256 tmp3, tmp2, tmp1, tmp0;
  tmp0 = _mm256_shuffle_ps(row0, row1, 0x44);
  tmp2 = _mm256_shuffle_ps(row0, row1, 0xEE);
  tmp1 = _mm256_shuffle_ps(row2, row3, 0x44);
  tmp3 = _mm256_shuffle_ps(row2, row3, 0xEE);
  row0 = _mm256_shuffle_ps(tmp0, tmp1, 0x88);
  row1 = _mm256_shuffle_ps(tmp0, tmp1, 0xDD);
  row2 = _mm256_shuffle_ps(tmp2, tmp3, 0x88);
  row3 = _mm256_shuffle_ps(tmp2, tmp3, 0xDD);
  dst[0] = row0; dst[1] = row1; dst[2] = row2; dst[3] = row3;
}

抱歉,但需要进行反转置。很遗憾我在主要问题中没有提到。 - user2454034
@user2454034:添加了反向转置,实际上它与之前的操作非常相似。 - stgatilov

1

向量类库(VCL)使用模板元编程来确定最佳的置换和混合内部函数。然而,当涉及到置换和混合时,您通常仍需要了解硬件的限制以获得最佳结果。

我将Stgatilov先生已经很好的答案转换为使用VCL,并且它产生理想的汇编代码(八个洗牌)。以下是该函数:

void tran8x4_AVX(float *a, float *b) {
    Vec8f tmp0, tmp1, tmp2, tmp3;
    Vec8f row0, row1, row2, row3;

    row0 = Vec8f().load(&a[8*0]);
    row1 = Vec8f().load(&a[8*1]);
    row2 = Vec8f().load(&a[8*2]);
    row3 = Vec8f().load(&a[8*3]);    

    tmp0 = blend8f<0, 1,  8, 9,  4, 5, 12, 13>(row0, row1);
    tmp2 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row0, row1);
    tmp1 = blend8f<0, 1,  8, 9,  4, 5, 12, 13>(row2, row3);
    tmp3 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row2, row3);

    row0 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp0, tmp1);
    row1 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp0, tmp1);
    row2 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp2, tmp3);
    row3 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp2, tmp3);

    row0.get_low().store(&b[  4*0]);
    row1.get_low().store(&b[  4*1]);
    row2.get_low().store(&b[  4*2]);
    row3.get_low().store(&b[  4*3]);
    row0.get_high().store(&b[ 4*4]);
    row1.get_high().store(&b[ 4*5]);
    row2.get_high().store(&b[ 4*6]);
    row3.get_high().store(&b[ 4*7]);
}

这里是汇编代码(g++ -S -O3 -mavx test.cpp

    vmovups 32(%rdi), %ymm4
    vmovups 64(%rdi), %ymm3
    vmovups (%rdi), %ymm1
    vmovups 96(%rdi), %ymm0
    vshufps $68, %ymm4, %ymm1, %ymm2
    vshufps $68, %ymm0, %ymm3, %ymm5
    vshufps $238, %ymm4, %ymm1, %ymm1
    vshufps $238, %ymm0, %ymm3, %ymm0
    vshufps $136, %ymm5, %ymm2, %ymm4
    vshufps $221, %ymm5, %ymm2, %ymm2
    vshufps $136, %ymm0, %ymm1, %ymm3
    vshufps $221, %ymm0, %ymm1, %ymm0
    vmovups %xmm4, (%rsi)
    vextractf128    $0x1, %ymm4, %xmm4
    vmovups %xmm2, 16(%rsi)
    vextractf128    $0x1, %ymm2, %xmm2
    vmovups %xmm3, 32(%rsi)
    vextractf128    $0x1, %ymm3, %xmm3
    vmovups %xmm0, 48(%rsi)
    vextractf128    $0x1, %ymm0, %xmm0
    vmovups %xmm4, 64(%rsi)
    vmovups %xmm2, 80(%rsi)
    vmovups %xmm3, 96(%rsi)
    vmovups %xmm0, 112(%rsi)
    vzeroupper
    ret
    .cfi_endproc

这是一个完整的测试。
#include <stdio.h>
#include "vectorclass.h"

void tran8x4(float *a, float *b) {
    for(int i=0; i<4; i++) {
        for(int j=0; j<8; j++) {
            b[j*4+i] = a[i*8+j];
        }
    }
}

void tran8x4_AVX(float *a, float *b) {
    Vec8f tmp0, tmp1, tmp2, tmp3;
    Vec8f row0, row1, row2, row3;

    row0 = Vec8f().load(&a[8*0]);
    row1 = Vec8f().load(&a[8*1]);
    row2 = Vec8f().load(&a[8*2]);
    row3 = Vec8f().load(&a[8*3]);


    tmp0 = blend8f<0, 1, 8, 9, 4, 5, 12, 13>(row0, row1);
    tmp2 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row0, row1);
    tmp1 = blend8f<0, 1, 8, 9, 4, 5, 12, 13>(row2, row3);
    tmp3 = blend8f<2, 3, 10, 11, 6, 7, 14, 15>(row2, row3);

    row0 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp0, tmp1);
    row1 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp0, tmp1);
    row2 = blend8f<0, 2, 8, 10, 4, 6, 12, 14>(tmp2, tmp3);
    row3 = blend8f<1, 3, 9, 11, 5, 7, 13, 15>(tmp2, tmp3);

    row0.get_low().store(&b[  4*0]);
    row1.get_low().store(&b[  4*1]);
    row2.get_low().store(&b[  4*2]);
    row3.get_low().store(&b[  4*3]);
    row0.get_high().store(&b[ 4*4]);
    row1.get_high().store(&b[ 4*5]);
    row2.get_high().store(&b[ 4*6]);
    row3.get_high().store(&b[ 4*7]);

}


int main() {
    float a[32], b1[32], b2[32];
    for(int i=0; i<32; i++) a[i] = i;
    for(int i=0; i<4; i++) {
        for(int j=0; j<8; j++) {
            printf("%2.0f ", a[i*8+j]);
        } puts("");
    }
    tran8x4(a,b1);
    tran8x4_AVX(a,b2);
    puts("");
    for(int i=0; i<8; i++) {
        for(int j=0; j<4; j++) {
            printf("%2.0f ", b1[i*4+j]);
        } puts("");
    }
    puts("");
    for(int i=0; i<8; i++) {
        for(int j=0; j<4; j++) {
            printf("%2.0f ", b2[i*4+j]);
        } puts("");
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接