使用 NEON 指令集转置 8x8 浮点矩阵

4

我有一个程序需要对8x8的float32矩阵执行转置操作多次。我想使用NEON SIMD指令集来进行转置。我知道这个数组将始终包含8x8的float元素。以下是基准非内嵌解决方案:

void transpose(float *matrix, float *matrixT) {
    for (int i = 0; i < 8; i++) {
        for (int j = 0; j < 8; j++) {
            matrixT[i*8+j] = matrix[j*8+i];
        }
    }
}

我还创建了一种内在的解决方案,它将8x8矩阵的每个4x4象限转置,并交换第二个和第三个象限的位置。这个解决方案看起来像这样:

void transpose_4x4(float *matrix, float *matrixT, int store_index) {
    float32x4_t r0, r1, r2, r3, c0, c1, c2, c3;
    r0 = vld1q_f32(matrix);
    r1 = vld1q_f32(matrix + 8);
    r2 = vld1q_f32(matrix + 16);
    r3 = vld1q_f32(matrix + 24);

    c0 = vzip1q_f32(r0, r1);
    c1 = vzip2q_f32(r0, r1);
    c2 = vzip1q_f32(r2, r3);
    c3 = vzip2q_f32(r2, r3);

    r0 = vcombine_f32(vget_low_f32(c0), vget_low_f32(c2));
    r1 = vcombine_f32(vget_high_f32(c0), vget_high_f32(c2));
    r2 = vcombine_f32(vget_low_f32(c1), vget_low_f32(c3));
    r3 = vcombine_f32(vget_high_f32(c1), vget_high_f32(c3));

    vst1q_f32(matrixT + store_index, r0);
    vst1q_f32(matrixT + store_index + 8, r1);
    vst1q_f32(matrixT + store_index + 16, r2);
    vst1q_f32(matrixT + store_index + 24, r3);
}

void transpose(float *matrix, float *matrixT) {
    // Transpose top-left 4x4 quadrant and store the result in the top-left 4x4 quadrant
    transpose_4x4(matrix, matrixT, 0);

    // Transpose top-right 4x4 quadrant and store the result in the bottom-left 4x4 quadrant
    transpose_4x4(matrix + 4, matrixT, 32);

    // Transpose bottom-left 4x4 quadrant and store the result in the top-right 4x4 quadrant
    transpose_4x4(matrix + 32, matrixT, 4);

    // Transpose bottom-right 4x4 quadrant and store the result in the bottom-right 4x4 quadrant
    transpose_4x4(matrix + 36, matrixT, 36);
}

然而,这种解决方案的性能比基线非内在解决方案慢。我很难看出是否有更快的解决方案可以转置我的8x8矩阵。任何帮助都将不胜感激!

编辑:两种解决方案都使用-O1标志编译。


2
可能最好包含一些额外的细节,比如你正在构建哪个ARM ISA,以及你正在使用哪些编译器选项。 - Thomas Jager
3
两种解决方案都使用了 -O1 标志进行编译。为什么不选择完全优化呢?请至少使用 -O2 来启用函数内联!最好是使用 -O3 -mcpu=cortex-a76 或与目标机器相匹配的选项。 - Peter Cordes
1
性能警报:没有单一(免费)的工具链可以像 vtrn vzipvuzp 这样进行排列组合,而不会通过在 ARM32 中膨胀二进制文件来添加无用的 vmovvorr。你最好用汇编语言编写它。 - Jake 'Alquimista' LEE
你在 vzip 中选择了错误的内置函数。由于这是一个“转置”问题,因此你应该专门使用 vtrn - Jake 'Alquimista' LEE
2个回答

5

首先,你不应该期望一开始就有巨大的性能提升:

  • 实际上没有计算;
  • 你正在处理32位数据,因此带宽约束不大。

总的来说,只需通过矢量化节省一点带宽即可。

至于4x4矩阵转置,你甚至不需要单独的函数,只需要一个宏即可:

#define TRANSPOSE4x4(pSrc,pDst) vst1q_f32_x4(pDst,vld4q_f32(pSrc))

如果你使用 vld4 加载数据,NEON 能够在加载时进行 4x4 转置,因此这个工作不需要手动实现。

但是你应该思考一下,在实际计算之前,是否采用转置所有矩阵的方法是正确的。如果 4x4 转置几乎没有成本,那么这一步可能会成为纯粹的计算和带宽浪费。优化不应仅限于最后一步,而应从设计阶段就考虑到。

然而,8x8 转置则有所不同:

void transpose8x8(float *pDst, float *pSrc)
    {
        float32x4_t row0a, row0b, row1a, row1b, row2a, row2b, row3a, row3b, row4a, row4b, row5a, row5b, row6a, row6b, row7a, row7b;
        float32x4_t r0a, r0b, r1a, r1b, r2a, r2b, r3a, r3b, r4a, r4b, r5a, r5b, r6a, r6b, r7a, r7b;

        row0a = vld1q_f32(pSrc);
        pSrc += 4;
        row0b = vld1q_f32(pSrc);
        pSrc += 4;
        row1a = vld1q_f32(pSrc);
        pSrc += 4;
        row1b = vld1q_f32(pSrc);
        pSrc += 4;
        row2a = vld1q_f32(pSrc);
        pSrc += 4;
        row2b = vld1q_f32(pSrc);
        pSrc += 4;
        row3a = vld1q_f32(pSrc);
        pSrc += 4;
        row3b = vld1q_f32(pSrc);
        pSrc += 4;
        row4a = vld1q_f32(pSrc);
        pSrc += 4;
        row4b = vld1q_f32(pSrc);
        pSrc += 4;
        row5a = vld1q_f32(pSrc);
        pSrc += 4;
        row5b = vld1q_f32(pSrc);
        pSrc += 4;
        row6a = vld1q_f32(pSrc);
        pSrc += 4;
        row6b = vld1q_f32(pSrc);
        pSrc += 4;
        row7a = vld1q_f32(pSrc);
        pSrc += 4;
        row7b = vld1q_f32(pSrc);

        r0a = vtrn1q_f32(row0a, row1a);
        r0b = vtrn1q_f32(row0b, row1b);
        r1a = vtrn2q_f32(row0a, row1a);
        r1b = vtrn2q_f32(row0b, row1b);
        r2a = vtrn1q_f32(row2a, row3a);
        r2b = vtrn1q_f32(row2b, row3b);
        r3a = vtrn2q_f32(row2a, row3a);
        r3b = vtrn2q_f32(row2b, row3b);
        r4a = vtrn1q_f32(row4a, row5a);
        r4b = vtrn1q_f32(row4b, row5b);
        r5a = vtrn2q_f32(row4a, row5a);
        r5b = vtrn2q_f32(row4b, row5b);
        r6a = vtrn1q_f32(row6a, row7a);
        r6b = vtrn1q_f32(row6b, row7b);
        r7a = vtrn2q_f32(row6a, row7a);
        r7b = vtrn2q_f32(row6b, row7b);

        row0a = vtrn1q_f64(row0a, row2a);
        row0b = vtrn1q_f64(row0b, row2b);
        row1a = vtrn1q_f64(row1a, row3a);
        row1b = vtrn1q_f64(row1b, row3b);
        row2a = vtrn2q_f64(row0a, row2a);
        row2b = vtrn2q_f64(row0b, row2b);
        row3a = vtrn2q_f64(row1a, row3a);
        row3b = vtrn2q_f64(row1b, row3b);
        row4a = vtrn1q_f64(row4a, row6a);
        row4b = vtrn1q_f64(row4b, row6b);
        row5a = vtrn1q_f64(row5a, row7a);
        row5b = vtrn1q_f64(row5b, row7b);
        row6a = vtrn2q_f64(row4a, row6a);
        row6b = vtrn2q_f64(row4b, row6b);
        row7a = vtrn2q_f64(row5a, row7a);
        row7b = vtrn2q_f64(row5b, row7b);

        vst1q_f32(pDst, row0a);
        pDst += 4;
        vst1q_f32(pDst, row4a);
        pDst += 4;
        vst1q_f32(pDst, row1a);
        pDst += 4;
        vst1q_f32(pDst, row5a);
        pDst += 4;
        vst1q_f32(pDst, row2a);
        pDst += 4;
        vst1q_f32(pDst, row6a);
        pDst += 4;
        vst1q_f32(pDst, row3a);
        pDst += 4;
        vst1q_f32(pDst, row7a);
        pDst += 4;
        vst1q_f32(pDst, row0b);
        pDst += 4;
        vst1q_f32(pDst, row4b);
        pDst += 4;
        vst1q_f32(pDst, row1b);
        pDst += 4;
        vst1q_f32(pDst, row5b);
        pDst += 4;
        vst1q_f32(pDst, row2b);
        pDst += 4;
        vst1q_f32(pDst, row6b);
        pDst += 4;
        vst1q_f32(pDst, row3b);
        pDst += 4;
        vst1q_f32(pDst, row7b);

    }

归根结底: 16次加载 + 32次传输 + 16次存储 vs 64次加载 + 64次存储

现在我们可以清楚地看到,它确实不值得。上面的NEON例程可能会更快一些,但我怀疑最终不会有太大区别。

不,你无法进一步优化它。没有人能够优化。只需确保指针是64字节对齐的,测试它,然后自行决定。

ld1     {v0.4s-v3.4s}, [x1], #64
ld1     {v4.4s-v7.4s}, [x1], #64
ld1     {v16.4s-v19.4s}, [x1], #64
ld1     {v20.4s-v23.4s}, [x1]

trn1    v24.4s, v0.4s, v2.4s    // row0
trn1    v25.4s, v1.4s, v3.4s
trn2    v26.4s, v0.4s, v2.4s    // row1
trn2    v27.4s, v1.4s, v3.4s
trn1    v28.4s, v4.4s, v6.4s    // row2
trn1    v29.4s, v5.4s, v7.4s
trn2    v30.4s, v4.4s, v6.4s    // row3
trn2    v31.4s, v5.4s, v7.4s
trn1    v0.4s, v16.4s, v18.4s   // row4
trn1    v1.4s, v17.4s, v19.4s
trn2    v2.4s, v16.4s, v18.4s   // row5
trn2    v3.4s, v17.4s, v19.4s
trn1    v4.4s, v20.4s, v22.4s   // row6
trn1    v5.4s, v21.4s, v23.4s
trn2    v6.4s, v20.4s, v22.4s   // row7
trn2    v7.4s, v21.4s, v23.4s

trn1    v16.2d, v24.2d, v28.2d  // row0a
trn1    v17.2d, v0.2d, v4.2d    // row0b
trn1    v18.2d, v26.2d, v30.2d  // row1a
trn1    v19.2d, v2.2d, v6.2d    // row1b
trn2    v20.2d, v24.2d, v28.2d  // row2a
trn2    v21.2d, v0.2d, v4.2d    // row2b
trn2    v22.2d, v26.2d, v30.2d  // row3a
trn2    v23.2d, v2.2d, v6.2d    // row3b

st1     {v16.4s-v19.4s}, [x0], #64
st1     {v20.4s-v23.4s}, [x0], #64

trn1    v16.2d, v25.2d, v29.2d  // row4a
trn1    v17.2d, v1.2d, v5.2d    // row4b
trn1    v18.2d, v27.2d, v31.2d  // row5a
trn1    v19.2d, v3.2d, v7.2d    // row5b
trn2    v20.2d, v25.2d, v29.2d  // row4a
trn2    v21.2d, v1.2d, v5.2d    // row4b
trn2    v22.2d, v27.2d, v31.2d  // row5a
trn2    v23.2d, v3.2d, v7.2d    // row5b

st1     {v16.4s-v19.4s}, [x0], #64
st1     {v20.4s-v23.4s}, [x0]

ret

上面是手工优化汇编版本,很可能更短(尽可能短),但不一定比下面的纯C版本更快。

下面是我会采用的纯C版本:

void transpose8x8(float *pDst, float *pSrc)
{
    uint32_t i = 8;
    do {
        pDst[0] = *pSrc++;
        pDst[8] = *pSrc++;
        pDst[16] = *pSrc++;
        pDst[24] = *pSrc++;
        pDst[32] = *pSrc++;
        pDst[40] = *pSrc++;
        pDst[48] = *pSrc++;
        pDst[56] = *pSrc++;
        pDst++;            
    } while (--i);
}

或者
void transpose8x8(float *pDst, float *pSrc)
{
    uint32_t i = 8;
    do {
        *pDst++ = pSrc[0];
        *pDst++ = pSrc[8];
        *pDst++ = pSrc[16];
        *pDst++ = pSrc[24];
        *pDst++ = pSrc[32];
        *pDst++ = pSrc[40];
        *pDst++ = pSrc[48];
        *pDst++ = pSrc[56];
        pSrc++;
    } while (--i);
}

PS:如果你将pDstpSrc声明为uint32_t *,可能会在性能/功耗方面带来一些收益,因为编译器肯定会生成最具有各种寻址模式的纯整数机器代码,并且只使用w寄存器而不是s寄存器。只需将float *转换为uint32_t *

PS2:Clang已经利用w寄存器而不是s寄存器,而GCC还是老样子.... GNU迷们什么时候才会承认GCC对ARM来说是一个极其糟糕的选择?
godbolt

PS3:以下是非neon版本的汇编代码(零延迟),因为我对上面的Clang和GCC非常失望(甚至震惊):

    .arch armv8-a
    .global transpose8x8
    .text

.balign 64
.func
transpose8x8:
    mov     w10, #8
    sub     x0, x0, #8
.balign 16
1:
    ldr     w2, [x1, #0]
    ldr     w3, [x1, #32]
    ldr     w4, [x1, #64]
    ldr     w5, [x1, #96]
    ldr     w6, [x1, #128]
    ldr     w7, [x1, #160]
    ldr     w8, [x1, #192]
    ldr     w9, [x1, #224]
    subs    w10, w10, #1
    stp     w2, w3, [x0, #8]
    add     x1, x1, #4
    stp     w4, w5, [x0, #16]
    stp     w6, w7, [x0, #24]
    stp     w8, w9, [x0, #32]!
    b.ne    1b
.balign 16
    ret
.endfunc
.end

如果您仍然坚持进行纯8x8转置,那么这可能是您得到的最好版本。它可能比neon汇编版本慢一些,但耗电量要少得多。


1

可以优化其他回答提出的8x8 neon代码; 8x8转置不仅可以被视为[A B;C D]' == [A' C'; B' D']的递归版本,还可以作为zip或unzip的重复应用。

  a b c d  
  e f g h 
  i j k l
  m n o p  == a b c d e f g h i j k l m n o p

  zip(first_half, last_half) ==
  zip(...) == a i b j c k d l e m f n g o h p
  zip(...) == a e i m b f j n c g k o d h l p == transpose

对于8x8矩阵,我们需要应用此算法3次,并通过vld4读取数据,其中两次已经完成。

   float32x4x4_t d0 = vld4q_f32(input);
   float32x4x4_t d1 = vld4q_f32(input + 16);
   float32x4x4_t d2 = vld4q_f32(input + 32);
   float32x4x4_t d3 = vld4q_f32(input + 48);
   float32x4x4_t e0 = {
       vzipq_f32(d0.val[0], d2.val[0]).val[0],
       vzipq_f32(d0.val[1], d2.val[1]).val[0],
       vzipq_f32(d0.val[2], d2.val[2]).val[0],
       vzipq_f32(d0.val[3], d2.val[3]).val[0]
   };
   float32x4x4_t e1 = {
       vzipq_f32(d1.val[0], d3.val[0]).val[0],
       vzipq_f32(d1.val[1], d3.val[1]).val[0],
       vzipq_f32(d1.val[2], d3.val[2]).val[0],
       vzipq_f32(d1.val[3], d3.val[3]).val[0]
   };
   float32x4x4_t e2 = {
       vzipq_f32(d0.val[0], d2.val[0]).val[1],
       vzipq_f32(d0.val[1], d2.val[1]).val[1],
       vzipq_f32(d0.val[2], d2.val[2]).val[1],
       vzipq_f32(d0.val[3], d2.val[3]).val[1]
   };
   float32x4x4_t e3 = {
       vzipq_f32(d1.val[0], d3.val[0]).val[1],
       vzipq_f32(d1.val[1], d3.val[1]).val[1],
       vzipq_f32(d1.val[2], d3.val[2]).val[1],
       vzipq_f32(d1.val[3], d3.val[3]).val[1]
   };
   vst1q_f32_x4(output, e0);
   vst1q_f32_x4(output + 16, e1);
   vst1q_f32_x4(output + 32, e2);
   vst1q_f32_x4(output + 48, e3);

使用vld1q_f32_x4开始,然后使用uzpq,最后使用vst4q_f32,即可执行转置操作。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接