如何使用SIMD计算字符出现次数

8
我被给定一个小写字符数组(最多1.5Gb)和一个字符c。我想使用AVX指令来找出字符c出现的次数。
    unsigned long long char_count_AVX2(char * vector, int size, char c){
    unsigned long long sum =0;
    int i, j;
    const int con=3;
    __m256i ans[con];
    for(i=0; i<con; i++)
        ans[i]=_mm256_setzero_si256();

    __m256i Zer=_mm256_setzero_si256();
    __m256i C=_mm256_set1_epi8(c);
    __m256i Assos=_mm256_set1_epi8(0x01);
    __m256i FF=_mm256_set1_epi8(0xFF);
    __m256i shield=_mm256_set1_epi8(0xFF);
    __m256i temp;
    int couter=0;
    for(i=0; i<size; i+=32){
        couter++;
        shield=_mm256_xor_si256(_mm256_cmpeq_epi8(ans[0], Zer), FF);
        temp=_mm256_cmpeq_epi8(C, *((__m256i*)(vector+i)));
        temp=_mm256_xor_si256(temp, FF);
        temp=_mm256_add_epi8(temp, Assos);
        ans[0]=_mm256_add_epi8(temp, ans[0]);
        for(j=1; j<con; j++){
            temp=_mm256_cmpeq_epi8(ans[j-1], Zer);
            shield=_mm256_and_si256(shield, temp);
            temp=_mm256_xor_si256(shield, FF);
            temp=_mm256_add_epi8(temp, Assos);
            ans[j]=_mm256_add_epi8(temp, ans[j]);
        }
    }
    for(j=con-1; j>=0; j--){
        sum<<=8;
        unsigned char *ptr = (unsigned char*)&(ans[j]);
        for(i=0; i<32; i++){
            sum+=*(ptr+i);
        }
    }
    return sum;
}

你的字符格式是什么?ASCII还是某种Unicode? - zx485
2
AVX1还是AVX2?你尝试过什么?提示:检查_mm256_cmpeq_epi8_mm256_sub_epi8以获取最内部循环。经过255次迭代,您需要开始将两个字节组合成一个uint16,依此类推。 - chtz
AVX2,至今我已经有一个包含4个__m256i变量的数组,并且我正在将溢出从索引0推到3。 - Adamos2468
1
_mm256_cmpeq_epi8 将在每个字节中得到一个 -1。如果您使用 _mm256_sub_epi8 从计数器中减去它,您可以直接计数到 255 或 128,即您最内层循环应该只包含这两个内嵌函数。 - chtz
2
一个核心通常无法饱和DRAM带宽,因此对于大型输入,使用多个线程可能是值得的(特别是如果您已经启动了工作线程并且可以只发送函数指针和参数)。您标记了这个[tag:parallel-processing],您是否也在询问OpenMP或其他东西? - Peter Cordes
显示剩余3条评论
3个回答

4

我有意留下一些部分,需要你自己解决(例如处理长度不是4*255*32字节的情况),但你最内层循环应该类似于以for(int i...)开头的循环:

_mm256_cmpeq_epi8会在每个字节中给出-1,你可以将其用作整数。如果你使用_mm256_sub_epi8从一个计数器中减去它,你可以直接计数到255或128。内部循环只包含这两个指令集。你必须停止并

#include <immintrin.h>
#include <stdint.h>

static inline
__m256i hsum_epu8_epu64(__m256i v) {
    return _mm256_sad_epu8(v, _mm256_setzero_si256());  // SAD against zero is a handy trick
}

static inline
uint64_t hsum_epu64_scalar(__m256i v) {
    __m128i lo = _mm256_castsi256_si128(v);
    __m128i hi = _mm256_extracti128_si256(v, 1);
    __m128i sum2x64 = _mm_add_epi64(lo, hi);   // narrow to 128

    hi = _mm_unpackhi_epi64(sum2x64, sum2x64);
    __m128i sum = _mm_add_epi64(hi, sum2x64);  // narrow to 64
    return _mm_cvtsi128_si64(sum);
}


unsigned long long char_count_AVX2(char const* vector, size_t size, char c)
{
    __m256i C=_mm256_set1_epi8(c);

    // todo: count elements and increment `vector` until it is aligned to 256bits (=32 bytes)
    __m256i const * simd_vector = (__m256i const *) vector;
     // *simd_vector is an alignment-required load, unlike _mm256_loadu_si256()

    __m256i sum64 = _mm256_setzero_si256();
    size_t unrolled_size_limit = size - 4*255*32 + 1;
    for(size_t k=0; k<unrolled_size_limit ; k+=4*255*32) // outer loop: TODO
    {
        __m256i counter[4]; // multiple counter registers to hide latencies
        for(int j=0; j<4; j++)
            counter[j]=_mm256_setzero_si256();
        // inner loop: make sure that you don't go beyond the data you can read
        for(int i=0; i<255; ++i)
        {   // or limit this inner loop to ~22 to avoid branch mispredicts
            for(int j=0; j<4; ++j)
            {
                counter[j]=_mm256_sub_epi8(counter[j],           // count -= 0 or -1
                                           _mm256_cmpeq_epi8(*simd_vector, C));
                ++simd_vector;
            }
        }

        // only need one outer accumulator: OoO exec hides the latency of adding into it
        sum64 = _mm256_add_epi64(sum64, hsum_epu8_epu64(counter[0]));
        sum64 = _mm256_add_epi64(sum64, hsum_epu8_epu64(counter[1]));
        sum64 = _mm256_add_epi64(sum64, hsum_epu8_epu64(counter[2]));
        sum64 = _mm256_add_epi64(sum64, hsum_epu8_epu64(counter[3]));
    }

    uint64_t sum = hsum_epu64_scalar(sum64);

    // TODO add up remaining bytes with sum.
    // Including a rolled-up vector loop before going scalar
    //  because we're potentially a *long* way from the end

    // Maybe put some logic into the main loop to shorten the 255 inner iterations
    // if we're close to the end.  A little bit of scalar work there shouldn't hurt every 255 iters.

    return sum;
}

Godbolt链接:https://godbolt.org/z/do5e3-(clang在展开最内层循环时略优于gcc:gcc包含一些无用的vmovdqa指令,如果数据在L1d缓存中热门,则会限制前端的瓶颈,防止我们接近每个时钟运行2倍32字节负载)


1
可以并且应该使用 _mm256_sad_epu8(counter, _mm256_setzero_si256()) 进行到 epu64 的扩展,然后使用 _mm256_add_epi64 将其加入一个向量中,在最后进行 hsum。 - Peter Cordes
1
我添加了代码来实现hsum和外部循环大小限制。请注意,clang使用索引寻址模式,因此它与gcc在Haswell/Skylake上运行的2个加载每个时钟周期的距离相同。:( 它们将从“vpcmpeqb”中解开并成为单独的uops在发射阶段。将其写为循环边界作为指针比较可能是更好的选择,并让clang执行纯指针增量而不是愚蠢的索引。例如:const char *endp = min(buf + size, buf + 4*255*32)或其他类似方式。 - Peter Cordes
感谢@PeterCordes的改进!我猜对于gcc来说,手动展开内部循环(即创建4个变量而不是数组)会更好。vpsadbw的技巧很棒。 - chtz
1
是的,这可能有助于GCC避免愚蠢的vmovdqa指令。如果你好奇的话,值得一试。或者报告一个错过优化的bug;它已经优化掉了任何对4个向量数组的存储/重新加载,而这显然是它应该能够优化的内容。无论如何,gcc的-funroll-loops只能手动启用或作为-fprofile-use的一部分启用;对于大型代码库,展开每个循环会带来更多的伤害而不是帮助,但是profile-use将识别热点循环并展开它们。我认为展开也可以让它避免额外的movdqa。 - Peter Cordes
“vpsadbw”技巧相对来说在8位数据求和方面比较出名,即使是有符号数也值得使用。通过异或进行范围移位,最后再减去“16或32 * 128”的偏置量。我认为Agner Fog的优化指南中提到了它,或者至少他的VectorClass库使用了它。 - Peter Cordes

3
如果您不坚持仅使用SIMD指令,可以结合VPMOVMSKB指令和POPCNT指令使用。前者将每个字节的最高位组合成32位整数掩码,后者计算此整数中的1位(即字符匹配的数量)。
int couter=0;
for(i=0; i<size; i+=32) {
  ...
  couter += 
    _mm_popcnt_u32( 
      (unsigned int)_mm256_movemask_epi8( 
        _mm256_cmpeq_epi8( C, *((__m256i*)(vector+i) ))
      ) 
    );
  ...
}    

我没有测试这个解决方案,但你应该能够理解要点。

我在OP的另一个已删除的问题中也有同样的想法。GMTA。 - Shawn
1
在内部循环中执行 _mm256_movemask_epi8_mm_popcnt_u32 比执行 _mm256_sub_epi8 效率要低得多。 - chtz
我想是的。但由于它的简单性,这也是一个值得提及的选择。 - zx485
1
可能作为清理循环的一部分有用,或者用于未对齐的开始/结束,在您弹出计数之前移出一些位,使用从重叠计算的移位计数。否则,更合理的“简单”版本是将psadbw epu8->epu64 hsum放在内部循环中,并使用_mm256_add_epi64。这只比高效方法多1个向量指令,而不是2个(vpcmpeqb + vpmovmskb + popcnt + add vs. vpcmpeqb (+vpsadbw) + vpsubb / q)。 - Peter Cordes

3

可能是最快的:memcount_avx2memcount_sse2

size_t memcount_avx2(const void *s, int c, size_t n) 
{    
  __m256i cv = _mm256_set1_epi8(c), 
          zv = _mm256_setzero_si256(), 
         sum = zv, acr0,acr1,acr2,acr3;
  const char *p,*pe;    

  for(p = s; p != (char *)s+(n- (n % (252*32)));) 
  { 
    for(acr0 = acr1 = acr2 = acr3 = zv, pe = p+252*32; p != pe; p += 128) 
    {
      acr0 = _mm256_sub_epi8(acr0, _mm256_cmpeq_epi8(cv, _mm256_lddqu_si256((const __m256i *)p))); 
      acr1 = _mm256_sub_epi8(acr1, _mm256_cmpeq_epi8(cv, _mm256_lddqu_si256((const __m256i *)(p+32)))); 
      acr2 = _mm256_sub_epi8(acr2, _mm256_cmpeq_epi8(cv, _mm256_lddqu_si256((const __m256i *)(p+64)))); 
      acr3 = _mm256_sub_epi8(acr3, _mm256_cmpeq_epi8(cv, _mm256_lddqu_si256((const __m256i *)(p+96)))); 
      __builtin_prefetch(p+1024);
    }
    sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr0, zv));
    sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr1, zv));
    sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr2, zv));
    sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr3, zv));
  } 

  for(acr0 = zv; p+32 < (char *)s + n; p += 32)  
    acr0 = _mm256_sub_epi8(acr0, _mm256_cmpeq_epi8(cv, _mm256_lddqu_si256((const __m256i *)p))); 
  sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr0, zv));

  size_t count = _mm256_extract_epi64(sum, 0) 
               + _mm256_extract_epi64(sum, 1) 
               + _mm256_extract_epi64(sum, 2) 
               + _mm256_extract_epi64(sum, 3);  

  while(p != (char *)s + n) 
      count += *p++ == c;
  return count;
}

基准测试:Skylake i7-6700 - 3.4GHz - GCC 8.3:

memcount_avx2:28 GB/s
memcount_sse:23 GB/s
char_count_AVX2:23 GB/s(来自 此帖


你可以使用 _mm256_sub_epi8 来累加 cmpeq 的结果,而不是在外部循环中浪费指令。此外,这段代码过于紧凑,缩进不正确。(也许是 SO markdown 中的制表符与空格问题?)我花了一些时间才找到你在内部循环迭代之间将 acr0..3 清零的位置;将它们声明为 外部 循环内会更有意义。我认为没有任何编译器支持 AVX2 但不支持 C99。我还会在单独的源代码行上进行结束指针计算。 - Peter Cordes
我不明白你所说的“浪费指令”是什么意思。 - powturbo
1
如果您使用 acr0 = _mm256_sub_epi8(acr0, cmp(...)),那么外部循环可以直接使用 acr0 而不是 _mm256_sub_epi8(zv, acr0)。在内部循环中使用 x -= -1 而不是 x += -1。在您的版本中,sub 是一条浪费的指令。 - Peter Cordes
1
谢谢Peter,我已经做出了更改并进行了新的基准测试。 - powturbo

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接