快速模12算法,用于将4个uint16_t打包成一个uint64_t。

8

Consider the following union:

union Uint16Vect {
    uint16_t _comps[4];
    uint64_t _all;
};

你是否有一种快速算法来判断每个组件是否等于12的模1?

一个朴素的代码序列是:

Uint16Vect F(const Uint16Vect a) {
    Uint16Vect r;
    for (int8_t k = 0; k < 4; k++) {
        r._comps[k] = (a._comps[k] % 12 == 1) ? 1 : 0;
    }
    return r;
}

1
这很慢吗? - user1196549
4
三元运算符的无用使用 -- 当 == 操作符扩展到整数类型时,它已经返回 1 表示真和 0 表示假。 - Ben Voigt
2
根据我的计算,如果你将你的16位数字乘以43691(0xAAAB)并向下移位19位,你会得到与除以12相同的结果。如果你将此乘以12并从原始数字中减去,应该可以得到你的“mod12”。中间结果适合32位,所以你可能可以在128位寄存器中做4个。我不知道是否可以使用AVX等一次做8个,但应该很容易找出(似乎你在那方面有经验)。真正的问题是这是否节省了你任何时间...如果它被矢量化,我认为应该会节省时间。 - yzt
3
@RaymondChen:是的,但这并不能使程序员免于阅读额外的代码,觉得它没意义,想知道其中的诀窍,最后决定它完全是多余的。当优化器不关心时,代码应该被写成最易读的(甚至在优化器关心时,如果它不是代码的性能关键部分的话)。 - Ben Voigt
1
@Ben,有人认为显式的 bool-expression ? 1 : 0 更易读,因为它不依赖于双关语(即比较运算符的结果实际上不是布尔值,而只是整数0或1)。在C++中,这是一种隐式类型转换。在其他语言中,您可能会得到类型不匹配的错误,甚至结果可能是 -1 - Raymond Chen
显示剩余5条评论
5个回答

12
编译器将对除以常数进行优化,转换为乘以倒数或乘法逆元。例如,x/12将被优化为bool h(uint16_t x) { return x % 12 == 1; } h(unsigned short): movzx eax, di imul eax, eax, 43691 ; = 0xFFFF*8/12 + 1 shr eax, 19 lea eax, [rax+rax*2] sal eax, 2 sub edi, eax cmp di, 1 sete al ret

由于SSE/AVX中有乘法指令,因此可以轻松进行向量化。此外,x = (x % 12 == 1)?1:0; 可以简化为x =(x%12 == 1),然后转换为x =(x-1)%12 == 0,从而避免从常量表加载值1进行比较。您可以使用向量扩展 ,让gcc自动生成代码。

typedef uint16_t ymm32x2 __attribute__((vector_size(32)));
ymm32x2 mod12(ymm32x2 x)
{
    return !!((x - 1) % 12);
}

以下是gcc的mod12(unsigned short __vector(16)): vpcmpeqd ymm3, ymm3, ymm3 ; ymm3 = -1 vpaddw ymm0, ymm0, ymm3 vpmulhuw ymm1, ymm0, YMMWORD PTR .LC0[rip] ; multiply with 43691 vpsrlw ymm2, ymm1, 3 vpsllw ymm1, ymm2, 1 vpaddw ymm1, ymm1, ymm2 vpsllw ymm1, ymm1, 2 vpcmpeqw ymm0, ymm0, ymm1 vpandn ymm0, ymm0, ymm3 ret

Clang和ICC不支持向量类型的!!,因此您需要改为(x - 1) % 12 == 0。不幸的是,编译器似乎不支持__attribute__((vector_size(8)))来发出MMX指令。但现在您应该使用SSE或AVX。

如上面的Godbolt链接所示,x % 12 == 1的输出更短,但您需要一个包含1s的表进行比较,这可能更好也可能不是。编译器无法完全优化手写代码,因此可以尝试使用内部函数将代码手动向量化。检查哪一个在您的情况下更快。

更好的方法是((x * 43691) & 0x7ffff) < 43691,或者如nwellnhof的答案中所述,x * 357913942 < 357913942,这也应该很容易向量化。


或者,对于像这样的小输入范围,可以使用查找表。基本版本需要一个65536元素的数组。

#define S1(x) ((x) + 0) % 12 == 1, ((x) + 1) % 12 == 1, ((x) + 2) % 12 == 1, ((x) + 3) % 12 == 1, \
              ((x) + 4) % 12 == 1, ((x) + 4) % 12 == 1, ((x) + 6) % 12 == 1, ((x) + 7) % 12 == 1
#define S2(x) S1((x + 0)*8), S1((x + 1)*8), S1((x + 2)*8), S1((x + 3)*8), \
              S1((x + 4)*8), S1((x + 4)*8), S1((x + 6)*8), S1((x + 7)*8)
#define S3(x) S2((x + 0)*8), S2((x + 1)*8), S2((x + 2)*8), S2((x + 3)*8), \
              S2((x + 4)*8), S2((x + 4)*8), S2((x + 6)*8), S2((x + 7)*8)
#define S4(x) S3((x + 0)*8), S3((x + 1)*8), S3((x + 2)*8), S3((x + 3)*8), \
              S3((x + 4)*8), S3((x + 4)*8), S3((x + 6)*8), S3((x + 7)*8)

bool mod12e1[65536] = {
    S4(0U), S4(8U), S4(16U), S4(24U), S4(32U), S4(40U), S4(48U), S4(56U)
}

只需将x % 12 == 1替换为mod12e1[x]即可使用。当然,这也可以进行矢量化处理。

但由于结果只有1或0,因此您还可以使用65536位数组将其大小减小到仅8KB。


您还可以通过4和3的整除性来检查12的整除性。显然,4的整除性是微不足道的。可以通过多种方式计算3的整除性:

  • 一种方法是计算奇数位数字之和与偶数位数字之和的差值,例如גלעד ברקן的答案中所示,并检查它是否可被3整除。

  • 或者,您可以检查在基于2的2k(如基于4、16、64...的基数)下的数字之和,以查看它是否可被3整除。

    这种方法之所以有效,是因为在基数为b时,要检查任何除数n是否能被b - 1整除,只需检查数字之和是否能被n整除。以下是实现:

  void modulo12equals1(uint16_t d[], uint32_t size) {
      for (uint32_t i = 0; i < size; i++)
      {
          uint16_t x = d[i] - 1;
          bool divisibleBy4 = x % 4 == 0;
          x = (x >> 8) + (x & 0x00ff); // max 1FE
          x = (x >> 4) + (x & 0x000f); // max 2D
          bool divisibleBy3 = !!((01111111111111111111111ULL >> x) & 1);
          d[i] = divisibleBy3 && divisibleBy4;
      }
  }

将3整除的功劳归功于Roland Illig

由于自动向量化的汇编输出过长,您可以在如何判断一个二进制数是否能被3整除?

  • 确定一个二进制数是否能被3整除
  • 位表示和3的整除性
  • 构建可判断3的整除性的电路
  • 检查一个数字是否能被3整除
  • 检查一个数字是否能被3整除的逻辑?

  • 我的第一反应是使用查找表,我在评论中提出并撤回了这个想法,因为我认为它可能不如操作快。 - גלעד ברקן
    用类似for(i=1;i<65536;i+=12) mod12e1[i]=true;这样的方式创建查找表是否更好?因为它不使用%运算符。 - kingW3
    @kingW3,我的表格在编译时构建,而你的是在运行时填充。 - phuclv

    2
    如果仅使用位运算和popcount有助于限制操作,我们可以观察到一个有效的候选者必须通过两个测试,因为减去1必须意味着可被4和3整除。首先,最后两位必须是01。然后,通过从偶数位置的popcount减去奇数位置的popcount可以找出可被3整除性。

    const evenMask = parseInt('1010101010101010', 2);
    // Leave out first bit, we know it will be zero
    // after subtracting 1
    const oddMask = parseInt('101010101010100', 2);
    
    console.log('n , Test 1: (n & 3)^3, Test 2: popcount diff:\n\n');
    
    for (let n=0; n<500; n++){
      if (n % 12 == 1)
        console.log(
          n,
          (n & 3)^3,
          popcount(n & evenMask) - popcount(n & oddMask))
    }
    
    // https://dev59.com/N1gQ5IYBdhLWcg3wOBSw
    function popcount(n) {
      var tmp = n;
      var count = 0;
      while (tmp > 0) {
        tmp = tmp & (tmp - 1);
        count++;
      }
      return count;
    }


    不幸的是,对于SIMD寄存器并没有popcnt - phuclv
    @phuclv 谢谢您的评论。我对此并不是很了解,但OP提到了AVX2 - 这些关于popcount的文章(和其他文章)可能相关吗?https://lemire.me/en/publication/arxiv1611.07612/ 和 https://news.ycombinator.com/item?id=11277891 - גלעד ברקן
    你可以使用一些额外的指令(或使用AVX512VPOPCNTDQ处理dword元素)进行SIMD popcount。你可能需要将vpsadbw(水平64位加法)替换为16位元素的水平加法(例如pmaddubsw),以获得每个16位块的popcnt。如果你拥有本地的vpopcntw (AVX512BITALG)(每个16位元素的SIMD popcnt),它甚至可能比使用高半乘法的普通定点乘法逆运算更快。 - Peter Cordes

    2
    最好的我能够想到的答案是"最初的回答"。
    uint64_t F(uint64_t vec) {
        //512 = 4 mod 12  -> max val 0x3FB
        vec = ((vec & 0xFE00FE00FE00FE00L) >> 7) + (vec & 0x01FF01FF01FF01FFL);
        //64 = 4 mod 12 -> max val 0x77
        vec = ((vec & 0x03C003C003C003C0L) >> 4) + (vec & 0x003F003F003F003FL);
        //16 = 4 mod 12 -> max val 0x27
        vec = ((vec & 0x0070007000700070L) >> 2) + (vec & 0x000F000F000F000FL);
        //16 = 4 mod 12 -> max val 0x13
        vec = ((vec & 0x0030003000300030L) >> 2) + (vec & 0x000F000F000F000FL);
        //16 = 4 mod 12 -> max val 0x0f
        vec = ((vec & 0x0030003000300030L) >> 2) + (vec & 0x000F000F000F000FL);
    
        //Each field is now 4 bits, and only 1101 and 0001 are 1 mod 12.
        //The top 2 bits must be equal and the other2 must be 0 and 1
    
        return vec & ~(vec>>1) & ~((vec>>2)^(vec>>3)) & 0x0001000100010001L;
    }
    

    2
    返回值应该是赋值(=)还是相等(==)? - Ben Voigt

    2
    最近Daniel Lemire的博客上有一篇关于快速余数计算和可除性检查的文章。例如,您可以使用((x * 43691) & 0x7ffff) < 43691或假设32位操作使用x * 357913942 < 357913942来检查12的可除性。这应该很容易并行化,但它需要32位乘法,不像phuclv答案中的代码。

    0
    请注意:x mod 12 == 1 意味着 x mod 4 == 1,而后者非常便宜。
    因此:
    is_mod_12 = ((input & 3) == 1) && ((input % 12) == 1);
    

    如果input mod 4经常不是1,那么这将节省您很多模运算。然而,如果input mod 4通常是1,则会稍微降低性能。


    & 周围缺少括号。由于模运算被优化为乘法,因此这个额外的 & 操作可能会导致反优化(即使输入均匀分布 - 我不确定,需要进行基准测试)。 - geza
    1
    @geza:按位与相比乘法的优势小于除法,但仍然显著。 - Ben Voigt
    4
    但是还有一个额外的分支,这可能会非常昂贵(gcc和clang都会将其编译为分支)。这个建议的性能高度取决于输入。如果该分支无法预测得很好,那么这段代码的运行速度比原始代码要慢得多。 - geza
    @geza:当然,我在第一个版本的答案中就已经意识到了潜在的非最优化问题。但是分支与模数是否被转换为使用乘法指令无关。 - Ben Voigt
    OP正在寻找带有AVX2的SIMD(但未能标记为此)。这意味着要做到一切“无分支”,因此我们始终需要%12的结果:/您可以尝试检查所有16x 2字节元素中(input&3)==1是否为false,以提前跳过%12,但根据数据,这不太可能跳过许多完整向量。您的向量越宽,元素越小,提前退出对于向量中的所有元素起作用的次数就越少。 - Peter Cordes

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接