快速计算数组中值为零的字节数

8

如何快速计算一个大的、连续数组中零值字节的数量?(或者反过来,非零字节的数量。)我所说的大数组是指216字节或更大。该数组的位置和长度可以包含任意字节对齐。

朴素的方法:

int countZeroBytes(byte[] values, int length)
{
    int zeroCount = 0;
    for (int i = 0; i < length; ++i)
        if (!values[i])
            ++zeroCount;

    return zeroCount;
}

对于我的问题,我通常只维护zeroCount并根据特定更改来更新它。然而,我想要一个快速的、通用的方法,在values发生任意批量更改后重新计算zeroCount。我相信有一种位操作方法可以更快地实现这个目标,但不幸的是,我只是个新手。

编辑:有几个人问到了零检查数据的性质,所以我会描述一下它。 (不过,如果解决方案仍然是通用的,那就太好了。)

基本上,想象一个由体素(例如Minecraft)组成的世界,其中程序生成的地形被分成立方体块,或者有效地作为三维数组索引的内存页面。每个体素都被视为唯一的字节,对应于唯一的材料(空气,石头,水等)。许多块只包含空气或水,而其他块则包含大量的2-4个体素的不同组合(泥土,沙子等),有效地有2-10%的体素是随机的离群值。存在大量体素的体素往往沿每个轴高度聚集。

看起来似乎零字节计数方法在许多无关的情况下都是有用的。因此,需要一个通用的解决方案。


2
寻找“人口计数”硬件指令和编译器内部函数(例如,这里是MSVC),并将它们应用于最大的字长。 - Kerrek SB
2
可能有一些利用特定架构的方法(例如特定操作码或x86上的SSE),但总的来说,我怀疑没有任何更快的方法。即使是查找表(例如16位块),也可能不会有所帮助,因为它们只会使缓存失效。 - Oliver Charlesworth
3
最优解可能取决于您是否认为零很常见。如果不是,一种按位“有零”操作,每次对一个大的(例如64位)字进行处理,可能是最佳方法,允许您跳过没有零的大段内容。然而,要注意别名违规问题... - R.. GitHub STOP HELPING ICE
也许一些棘手的实现,例如对于 strlen,可能会有所帮助...(通过 uint32_t* 迭代而不是天真的 uint8_t)。 - Jarod42
1
目前不将此作为答案发布,但你是否查看过位操作技巧中的相关代码 - Hasturkun
显示剩余8条评论
6个回答

7
这是如何使用SIMD计算字符出现次数的一个特殊情况,其中c=0表示要匹配的char(byte)值。请参阅该Q&A以获取手动向量化的AVX2实现char_count (char const* vector, size_t size, char c);,其内部循环比此更紧凑,避免将每个0/-1匹配向量分别减少为标量。
这将以O(n)的速度进行,因此您能做的最好的事情是减少常量。一个快速修复方法是删除分支。如果零是随机分布的,则此操作可以提供与下面我的SSE版本一样快的结果。这很可能是因为GCC向量化了这个循环。然而,对于长时间的零运行或零的随机密度小于1%的情况,下面的SSE版本仍然更快。
int countZeroBytes_fix(char* values, int length) {
    int zeroCount = 0;
    for(int i=0; i<length; i++) {
        zeroCount += values[i] == 0;
    }
    return zeroCount;
}

我最初认为零的密度很重要。至少在使用SSE时,事实并非如此。使用SSE速度更快,与密度无关。

编辑:实际上,它确实取决于密度,只是零的密度必须比我预期的小。 1/64的零(1.5%的零)是1/4 SSE寄存器中的一个零,因此分支预测效果不佳。然而,1/1024的零(0.1%的零)速度更快(请参见时间表)。

如果数据中有长时间的零,则SIMD速度更快。

您可以将16个字节打包到SSE寄存器中。然后,您可以使用_mm_cmpeq_epi8一次比较所有16个字节是否为零。然后,为了处理零的运行,您可以对结果使用_mm_movemask_epi8,大多数情况下它将为零。在这种情况下,您可以获得高达16倍的加速(对于第一半为1,第二半为零的情况,我获得了超过12倍的加速)。

以下是10000次重复的2^16字节的时间表(以秒为单位)。

                     1.5% zeros  50% zeros  0.1% zeros 1st half 1, 2nd half 0
countZeroBytes       0.8s        0.8s       0.8s        0.95s
countZeroBytes_fix   0.16s       0.16s      0.16s       0.16s
countZeroBytes_SSE   0.2s        0.15s      0.10s       0.07s

您可以在http://coliru.stacked-crooked.com/a/67a169ddb03d907a中查看最近1/2个零的结果。

#include <stdio.h>
#include <stdlib.h>
#include <emmintrin.h>                 // SSE2
#include <omp.h>

int countZeroBytes(char* values, int length) {
    int zeroCount = 0;
    for(int i=0; i<length; i++) {
        if (!values[i])
            ++zeroCount;
    }
    return zeroCount;
}

int countZeroBytes_SSE(char* values, int length) {
    int zeroCount = 0;
    __m128i zero16 = _mm_set1_epi8(0);
    __m128i and16 = _mm_set1_epi8(1);
    for(int i=0; i<length; i+=16) {
        __m128i values16 = _mm_loadu_si128((__m128i*)&values[i]);
        __m128i cmp = _mm_cmpeq_epi8(values16, zero16);
        int mask = _mm_movemask_epi8(cmp);
        if(mask) {
            if(mask == 0xffff) zeroCount += 16;
            else {
                cmp = _mm_and_si128(and16, cmp); //change -1 values to 1
                //hortiontal sum of 16 bytes
                __m128i sum1 = _mm_sad_epu8(cmp,zero16);
                __m128i sum2 = _mm_shuffle_epi32(sum1,2);
                __m128i sum3 = _mm_add_epi16(sum1,sum2);
                zeroCount += _mm_cvtsi128_si32(sum3);
            }
        }
    }
    return zeroCount;
}

int main() {
    const int n = 1<<16;
    const int repeat = 10000;
    char *values = (char*)_mm_malloc(n, 16);
    for(int i=0; i<n; i++) values[i] = rand()%64;  //1.5% zeros
    //for(int i=0; i<n/2; i++) values[i] = 1;
    //for(int i=n/2; i<n; i++) values[i] = 0;
    
    int zeroCount = 0;
    double dtime;
    dtime = omp_get_wtime();
    for(int i=0; i<repeat; i++) zeroCount = countZeroBytes(values,n);
    dtime = omp_get_wtime() - dtime;
    printf("zeroCount %d, time %f\n", zeroCount, dtime);
    dtime = omp_get_wtime();
    for(int i=0; i<repeat; i++) zeroCount = countZeroBytes_SSE(values,n);
    dtime = omp_get_wtime() - dtime;
    printf("zeroCount %d, time %f\n", zeroCount, dtime);       
}

2
您只需要每127次水平求和(使用SAD)一下,以避免溢出。在此之前,您可以使用PADDB对比较结果进行累加,将其视为“-1”或“0”的向量。即使在全零或全非零的情况下,这也比这个更有效率。即使保持简单并且每次迭代都对64位计数器的向量进行水平求和也很好。(即循环内部PCMPEQB / PSADBW / PADDQ,再加上一个或两个MOVDQA)。 - Peter Cordes

5
我提供了这个OpenMP实现,可以利用每个处理器的本地缓存来并行读取数组。
nzeros_total = 0;
#pragma omp parallel for reduction(+:nzeros_total)
    for (i=0;i<NDATA;i++)
    {
        if (v[i]==0)
            nzeros_total++;
    }

一个快速的基准测试,包括运行1000次for循环的天真实现(与OP在问题中编写的相同)与OpenMP实现进行比较,每种方法都运行1000次,使用65536个int数组,零值元素概率为50%,在QuadCore CPU上使用Windows 7进行编译,并使用VStudio 2012 Ultimate编译,得到以下数字:

               DEBUG               RELEASE
Naive method:  580 microseconds.   341 microseconds.
OpenMP method: 159 microseconds.    99 microseconds.

注意:我尝试了#pragma loop (hint_parallel(4)),但显然这并没有使朴素版本表现更好,所以我的猜测是编译器已经应用了这个优化,或者根本无法应用。另外,#pragma loop (no_vector)并没有使朴素版本表现更差。


“在所有缓存中”的假设是一个很大的假设,但如果它成立,这是一种简单而有效的技术。” - Oliver Charlesworth
1
如果你移除分支,只执行 nzeros_total += v[i] == 0,这可能会更快。 - Z boson
虽然这似乎很有效,但我目前不幸没有多余的处理器。谢谢你。 - Philip Guin
我将此标记为最佳答案,因为它似乎既具有可移植性又高效。然而,如果有人发明了一些既具有可移植性、快速、非平凡且非并行的东西,我会很高兴接受它。(除了惊叹之外!) - Philip Guin

2

您还可以使用POPCNT指令来返回置位位数。这样可以进一步简化代码,并通过消除不必要的分支来提高速度。以下是AVX2和POPCNT的示例:

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include "immintrin.h"

int countZeroes(uint8_t* bytes, int length)
{
    const __m256i vZero = _mm256_setzero_si256();
    int count = 0;
    for (int n = 0; n < length; n += 32)
    {
        __m256i v = _mm256_load_si256((const __m256i*)&bytes[n]);
        v = _mm256_cmpeq_epi8(v, vZero);
        int k = _mm256_movemask_epi8(v);
        count += _mm_popcnt_u32(k);
    }
    return count;
}

#define SIZE 1024

int main()
{
    uint8_t bytes[SIZE] __attribute__((aligned(32)));

    for (int z = 0; z < SIZE; ++z)
        bytes[z] = z % 2;

    int n = countZeroes(bytes, SIZE);
    printf("%d\n", n);

    return 0;
}

1
您可以使用cmpeq_epi8结果作为0 / -1整数,累加到向量累加器中,并仅在最后进行添加,从而收紧内部循环:如何使用SIMD计算字符出现次数。 在内部循环中的工作量比movemask + popcnt +标量加法少,但需要嵌套循环以避免大“长度”时的溢出。 - Peter Cordes

1
对于0很常见的情况,一次检查64个字节会更快,并且仅在跨度非零时才检查字节。如果0很少出现,则这将更加昂贵。此代码假定大块可被64整除。这还假定memcmp是您可以获得的最有效率的方式。
int countZeroBytes(byte[] values, int length)
{
    static const byte zeros[64]={};

    int zeroCount = 0;
    for (int i = 0; i < length; i+=64)
    {
        if (::memcmp(values+i, zeros, 64) == 0)
        {
             zeroCount += 64;
        }
        else
        {
               for (int j=i; j < i+64; ++j)
               {
                     if (!values[j])
                     {
                          ++zeroCount;
                     }
               }
        }
    }

    return zeroCount;
}

嗯,有时候数据完全为零,其他时候很少。但我喜欢这个方法的原因是,如果先前的“零计数”很高,我可以使用它,甚至在“零计数”超过某个阈值时,在固定间隔内切换到另一种方法。 - Philip Guin

1
暴力计算零字节:使用向量比较指令,如果该字节为0,则将向量的每个字节设置为1,否则将其设置为0。
重复255次以处理最多255 x 64字节(如果您有512位指令可用),或者255 x 32或255 x 16字节(如果您只有128位向量)。然后,您只需将255个结果向量相加即可。由于比较后的每个字节的值都为0或1,因此每个总和最多为255,因此现在您有一个64/32/16字节的向量,而不是约16,000/8,000/4,000字节。

0

避免条件并将其交换为查找和添加可能会更快:

char isCharZeroLUT[256] = { 1 }; /* 1 0 0 ... */
int zeroCount = 0;
for (int i = 0; i < length; ++i) {
    zeroCount += isCharZeroLUT[values[i]];
}

我还没有测量过差异。值得注意的是,某些编译器可以愉快地向量化足够简单的循环。


@Jongware:很有可能编译器已经在做类似的事情了,即避免条件分支。 - Oliver Charlesworth
@Jongware:有可能。我不确定是否可以通过一些位操作或内部使用条件来完成。很可能它避免了条件和查找。 - Dietmar Kühl
1
如果您的C实现具有负值的(实现定义的)右移的通常定义,则 count -= ((unsigned char)values[i]-1)>>8; 将完成它。 - R.. GitHub STOP HELPING ICE
1
@Jongware,当values[i] != 0时,您的代码将会增加,因此我认为zeroCount += (values[i] == 0)更加正确。 - phuclv
@LưuVĩnhPhúc:OP认为“计算非零”同样是一个好的解决方案。我提出的表达式是为了避免显式比较(我知道在汇编级别仍可能发生比较)。 - Jongware

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接