NEON矢量化无符号字节积和求和:(a[i]-int1) * (b[i]-int2)

4
我需要改进一个循环,因为它被我的应用程序调用了数千次。我想使用Neon来完成这个任务,但我不知道从哪里开始。
假设/前提条件:
- `w` 总是 320 (16/32 的倍数) - `pa` 和 `pb` 是 16 字节对齐的 - `ma` 和 `mb` 是正数。
 int whileInstruction (const unsigned char *pa,const unsigned char *pb,int ma,int mb,int w)
{
    int sum=0;

    do {
        sum += ((*pa++)-ma)*((*pb++)-mb);

    } while(--w);


    return sum;
}

这个尝试进行向量化的方法效果不佳,也不安全(缺少清除器),但是演示了我想要做的事情:
int whileInstruction (const unsigned char *pa,const unsigned char *pb,int ma,int mb,int w)
{

    asm volatile("lsr          %2, %2, #3      \n"
                 ".loop:                       \n"
                 "# load 8 elements:             \n"
                 "vld4.8      {d0-d3}, [%1]!   \n"
                 "vld4.8      {d4-d7}, [%2]!   \n"
                 "# do the operation:     \n"
                 "vaddl.u8    q7, d0, r7       \n"
                 "vaddl.u8    q8, d1, d8       \n"
                 "vmlal.u8    q7, q7, q8       \n"
                 "# Sum the vector a save in sum (this is wrong):\n"
                 "vaddl.u8    q7, d0, r7       \n"
                 "subs        %2, %2, #1       \n" // Decrement iteration count
                 "bne         .loop            \n" // Repeat unil iteration count is not zero
                 :
                 : "r"(pa), "r"(pb), "r"(w),"r"(ma),"r"(mb),"r"(sum)
                 : "r4", "r5", "r6","r7","r8","r9"
                 );

    return sum;
}

1
什么是循环?一些限制条件会有所帮助——ma、mb 的可能范围是多少?它们总是正数吗?它们是否适合无符号字符的范围(0..255)?w呢?我们可以假设它是8或16的倍数,还是它可以取任何值? - Paul R
1
“bucle”的西班牙语到英语翻译是“loop”。OP需要加速循环操作。 - zaph
1
你是指循环吗? - INS
是的,这是一个循环,抱歉。 - Gustavo
编辑回答问题,谢谢 - Gustavo
显示剩余5条评论
2个回答

7

这里是一个简单的NEON(ARM SIMD指令集)实现。我已经测试过,并与标量代码进行了比较以确保其正常工作。请注意,为了获得最佳性能,papb都应该是16字节对齐的。

#include <arm_neon.h>

int whileInstruction_neon(const unsigned char *pa, const unsigned char *pb, int ma, int mb, int w)
{
    int sum = 0;

    const int32x4_t vma = { ma, ma, ma, ma };
    const int32x4_t vmb = { mb, mb, mb, mb };

    int32x4_t vsumll = { 0 };
    int32x4_t vsumlh = { 0 };
    int32x4_t vsumhl = { 0 };
    int32x4_t vsumhh = { 0 };
    int32x4_t vsum;

    int i;

    for (i = 0; i <= (w - 16); i += 16)
    {
        uint8x16_t va = vld1q_u8(pa);   // load vector from pa
        uint8x16_t vb = vld1q_u8(pb);   // load vector from pb

        // unpack va into 4 vectors

        int16x8_t val =  (int16x8_t)vmovl_u8(vget_low_u8(va));
        int16x8_t vah =  (int16x8_t)vmovl_u8(vget_high_u8(va));
        int32x4_t vall = vmovl_s16(vget_low_s16(val));
        int32x4_t valh = vmovl_s16(vget_high_s16(val));
        int32x4_t vahl = vmovl_s16(vget_low_s16(vah));
        int32x4_t vahh = vmovl_s16(vget_high_s16(vah));

        // subtract means

        vall = vsubq_s32(vall, vma);
        valh = vsubq_s32(valh, vma);
        vahl = vsubq_s32(vahl, vma);
        vahh = vsubq_s32(vahh, vma);

        // unpack vb into 4 vectors

        int16x8_t vbl =  (int16x8_t)vmovl_u8(vget_low_u8(vb));
        int16x8_t vbh =  (int16x8_t)vmovl_u8(vget_high_u8(vb));
        int32x4_t vbll = vmovl_s16(vget_low_s16(vbl));
        int32x4_t vblh = vmovl_s16(vget_high_s16(vbl));
        int32x4_t vbhl = vmovl_s16(vget_low_s16(vbh));
        int32x4_t vbhh = vmovl_s16(vget_high_s16(vbh));

        // subtract means

        vbll = vsubq_s32(vbll, vmb);
        vblh = vsubq_s32(vblh, vmb);
        vbhl = vsubq_s32(vbhl, vmb);
        vbhh = vsubq_s32(vbhh, vmb);

        // update 4 partial sum of products vectors

        vsumll = vmlaq_s32(vsumll, vall, vbll);
        vsumlh = vmlaq_s32(vsumlh, valh, vblh);
        vsumhl = vmlaq_s32(vsumhl, vahl, vbhl);
        vsumhh = vmlaq_s32(vsumhh, vahh, vbhh);

        pa += 16;
        pb += 16;
    }

    // sum 4 partial sum of product vectors

    vsum = vaddq_s32(vsumll, vsumlh);
    vsum = vaddq_s32(vsum, vsumhl);
    vsum = vaddq_s32(vsum, vsumhh);

    // do scalar horizontal sum across final vector

    sum = vgetq_lane_s32(vsum, 0);
    sum += vgetq_lane_s32(vsum, 1);
    sum += vgetq_lane_s32(vsum, 2);
    sum += vgetq_lane_s32(vsum, 3);

    // handle any residual non-multiple of 16 points

    for ( ; i < w; ++i)
    {
        sum +=  (*pa++ - ma) * (*pb++ - mb);
    }

    return sum;
}

编译器返回:"error: no matching function for call to '__simd128_int16_t::__simd128_int16_t(__simd128_uint16_t)'",你知道为什么吗? - Gustavo
1
你使用的编译器是什么?哪一行代码产生了错误?你是将其编译为C还是C++?顺便说一句,我使用gcc 4.5.1进行了测试(编译为C代码)。 - Paul R
好的,我把编译器改成了4.2版本,现在它可以工作了,但是比C语言版本要慢。我正在查找原因。 - Gustavo
1
请确保您使用-O3进行编译。同时,您可能正在使用某个编译器自动将标量循环向量化。FWIW,我只在使用gcc 4.5.1时获得了约1.5倍的改进-如果您需要获得大约3倍的改进,则需要使用汇编语言。 - Paul R
3
顺便说一句,如果这是某种统计计算的一部分,例如归一化交叉相关性,那么在进行微观优化之前,你可能需要退后一步并修复整体算法。 - Paul R

1

针对我的问题,Paul R提供了另一个完美的解决方案。在w等于8的情况下,通常可以使用此函数:

int whileInstruction8Valors (const unsigned char *pa,const unsigned char *pb,int ma,int mb,int w)
{

int sum=0;
//int 32 bits /4 elementos? 

const int32x4_t vma = { ma, ma, ma, ma };
const int32x4_t vmb = { mb, mb, mb, mb };

int32x4_t vsumll = { 0 };
int32x4_t vsumlh = { 0 };

int32x4_t vsum;

//char 8 bytes / 8 elementos
uint8x8_t  va2= vld1_u8(pa); // VLD1.8 {d0}, [r0]
uint8x8_t  vb2= vld1_u8(pb); // VLD1.8 {d0}, [r0]

//int 16 bytes /8 elementos
int16x8_t val =  (int16x8_t)vmovl_u8(va2);

//int 32 /4 elementos *2 
int32x4_t vall = vmovl_s16(vget_low_s16(val));
int32x4_t valh = vmovl_s16(vget_high_s16(val));

// subtract means
vall = vsubq_s32(vall, vma);
valh = vsubq_s32(valh, vma);

//int 16 bytes /8 elementos
int16x8_t vbl2 =  (int16x8_t)vmovl_u8(vb2);

//int 32 /4 elementos *2 
int32x4_t vbll = vmovl_s16(vget_low_s16(vbl2));
int32x4_t vblh = vmovl_s16(vget_high_s16(vbl2));

// subtract means

vbll = vsubq_s32(vbll, vmb);
vblh = vsubq_s32(vblh, vmb);

// update 4 partial sum of products vectors

vsumll = vmlaq_s32(vsumll, vall, vbll);
vsumlh = vmlaq_s32(vsumlh, valh, vblh);

// sum 4 partial sum of product vectors

vsum = vaddq_s32(vsumll, vsumlh);

// do scalar horizontal sum across final vector

sum = vgetq_lane_s32(vsum, 0);
sum += vgetq_lane_s32(vsum, 1);
sum += vgetq_lane_s32(vsum, 2);
sum += vgetq_lane_s32(vsum, 3);

return sum;
}

也许可以改进它。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接