如何执行 _mm256_movemask_epi8 (VPMOVMSKB) 的反操作?

26

内在价值:

int mask = _mm256_movemask_epi8(__m256i s1)
创建一个掩码,其32位对应于s1每个字节的最高有效位。使用位操作(例如BMI2)操作掩码后,我想执行_mm256_movemask_epi8的反操作,即创建一个__m256i向量,其中每个字节的最高有效位包含uint32_t mask的相应位。
最好的方法是什么?
编辑: 我需要执行反向操作,因为内置函数_mm256_blendv_epi8只接受__m256i类型的掩码,而不是uint32_t类型。因此,在结果__m256i掩码中,我可以忽略除每个字节的MSB之外的位。

2
使用AVX512,您可以使用_mm256_mask_blend_epi8(__mmask32 k, __m256i a, __m256i b)函数,将整数作为掩码。 - technosaurus
参见我在一个可能的重复问题上的回答。使用vpsllvd变量移位将掩码的不同位放入每个元素的符号位中。这对于元素大小为32b非常有用,但对于8b则不是很好。 - Peter Cordes
在Intel AVX2中是否有movemask指令的反向指令?is there an inverse instruction to the movemask instruction in intel avx2?列出了不同版本(SSE和AVX)的不同元素大小。 - Peter Cordes
5个回答

19

我已在Haswell机器上实施了上述三种方法。Evgeny Kluev的方法最快(1.07秒),其次是Jason R的(1.97秒)和Paul R的(2.44秒)。下面的代码使用了-march=core-avx2 -O3优化标志进行编译。

#include <immintrin.h>
#include <boost/date_time/posix_time/posix_time.hpp>

//t_icc = 1.07 s
//t_g++ = 1.09 s
__m256i get_mask3(const uint32_t mask) {
  __m256i vmask(_mm256_set1_epi32(mask));
  const __m256i shuffle(_mm256_setr_epi64x(0x0000000000000000,
      0x0101010101010101, 0x0202020202020202, 0x0303030303030303));
  vmask = _mm256_shuffle_epi8(vmask, shuffle);
  const __m256i bit_mask(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe));
  vmask = _mm256_or_si256(vmask, bit_mask);
  return _mm256_cmpeq_epi8(vmask, _mm256_set1_epi64x(-1));
}

//t_icc = 1.97 s
//t_g++ = 1.97 s
__m256i get_mask2(const uint32_t mask) {
  __m256i vmask(_mm256_set1_epi32(mask));
  const __m256i shift(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0));
  vmask = _mm256_sllv_epi32(vmask, shift);
  const __m256i shuffle(_mm256_setr_epi64x(0x0105090d0004080c,
      0x03070b0f02060a0e, 0x0105090d0004080c, 0x03070b0f02060a0e));
  vmask = _mm256_shuffle_epi8(vmask, shuffle);
  const __m256i perm(_mm256_setr_epi64x(0x0000000000000004, 0x0000000100000005,
      0x0000000200000006, 0x0000000300000007));
  return _mm256_permutevar8x32_epi32(vmask, perm);
}

//t_icc = 2.44 s
//t_g++ = 2.45 s
__m256i get_mask1(uint32_t mask) {
  const uint64_t pmask = 0x8080808080808080ULL; // bit unpacking mask for PDEP
  uint64_t amask0, amask1, amask2, amask3; 
  amask0 = _pdep_u64(mask, pmask);
  mask >>= 8;
  amask1 = _pdep_u64(mask, pmask);
  mask >>= 8;
  amask2 = _pdep_u64(mask, pmask);
  mask >>= 8;
  amask3 = _pdep_u64(mask, pmask);
  return _mm256_set_epi64x(amask3, amask2, amask1, amask0);
}

int main() {
  __m256i mask;
  boost::posix_time::ptime start(
      boost::posix_time::microsec_clock::universal_time()); 
  for(unsigned i(0); i != 1000000000; ++i)
    { 
      mask = _mm256_xor_si256(mask, get_mask3(i));
    }
  boost::posix_time::ptime end(
      boost::posix_time::microsec_clock::universal_time());
  std::cout << "duration:" << (end-start) << 
    " mask:" << _mm256_movemask_epi8(mask) << std::endl;
  return 0;
}

4
感谢您跟进了所有三个建议,并提供了对结果的简洁概述!顺便问一下,您使用了哪个编译器? - Paul R
2
谢谢!我同时使用了icc和g++。我已经使用优化标志更新了时间。 - Satya Arjunan
1
就我个人而言,我在这里使用clang进行了一些基准测试,并获得了类似的结果。 - Paul R
1
clang 结果:get_mask3: 0.9968 纳秒,get_mask2: 1.7413 纳秒,get_mask1: (check = 0) 2.291 纳秒 - Paul R

10

这里有一种替代 LUT 或 pdep 指令的方法,可能会更加高效:

  1. 将您的 32 位掩码复制到某个 ymm 寄存器的低字节和同一寄存器的字节 16..19 中。您可以使用临时数组和 _mm256_load_si256。或者您可以将单个 32 位掩码的副本移动到某个 ymm 寄存器的低字节中,然后使用 VPBROADCASTD (_mm_broadcastd_epi32) 或其他广播/洗牌指令进行广播。
  2. 重新排列寄存器的字节,以使低 8 个字节(每个字节)包含掩码的低 8 位,接下来的 8 个字节-下一个 8 位,依此类推。这可以使用控制寄存器包含低 8 字节中的“0”,下一个 8 字节中的“1” 等的 VPSHUFB (_mm256_shuffle_epi8) 完成。
  3. 使用 VPOR (_mm256_or_si256)VPAND (_mm256_and_si256) 选择每个字节的适当位。
  4. 使用 VPCMPEQB (_mm256_cmpeq_epi8) 设置适当字节的 MSB。将每个字节与 0xFF 进行比较。如果要切换掩码的每个位,请在前一步上使用 VPAND 并与零进行比较。

这种方法的额外灵活性在于,您可以选择不同的控制寄存器来进行步骤#2和不同的掩码用于步骤#3以对位掩码进行洗牌(例如,您可以按相反顺序将此掩码复制到 ymm 寄存器中)。


1
只需使用 _mm256_set1_epi32,让编译器自行执行 vpbroadcastd ymm,[mem] 广播加载,如果需要的话。 - Peter Cordes
1
洗牌后,使用VPAND和VPCMPEQB实现“bitmap &(1 << bit)==(1 << bit)”。 您只需要一个向量常量。 - Peter Cordes
1
如果你想要 0/1 而不是 0/0xff,请使用 _mm256_min_epu8(and_result, _mm256_set1_epi8(1)) 替代对 AND 掩码的 cmpeq。具有非零字节的元素将具有最小值为 1,而不是 min(0,1) = 0。(这个技巧来自于 如何使用 x86 SIMD 高效地将 8 位位图转换为 0/1 整数数组) - Peter Cordes

4

我的初始方法与 @Jason R's 相似,因为这是“正常”操作的方式,但是大多数这些操作只关心高位 - 忽略所有其他位。一旦我意识到这一点,_mm*_maskz_broadcast*_epi*(mask,__m128i) 函数系列就变得最有意义了。您需要启用 -mavx512vl 和 -mavx512bw(gcc)。

要根据掩码设置每个字节的最高位的向量:

/* convert 16 bit mask to __m128i control byte mask */
_mm_maskz_broadcastb_epi8((__mmask16)mask,_mm_set1_epi32(~0))
/* convert 32 bit mask to __m256i control byte mask */
_mm256_maskz_broadcastb_epi8((__mmask32)mask,_mm_set1_epi32(~0))
/* convert 64 bit mask to __m512i control byte mask */
_mm512_maskz_broadcastb_epi8((__mmask64)mask,_mm_set1_epi32(~0))

根据掩码获取每个的最高位设置为向量:

/* convert 8 bit mask to __m128i control word mask */
_mm_maskz_broadcastw_epi16((__mmask8)mask,_mm_set1_epi32(~0))
/* convert 16 bit mask to __m256i control word mask */
_mm256_maskz_broadcastw_epi16((__mmask16)mask,_mm_set1_epi32(~0))
/* convert 32 bit mask to __m512i control word mask */
_mm512_maskz_broadcastw_epi16((__mmask32)mask,_mm_set1_epi32(~0))

要根据掩码获取每个双字的最高位设置为向量:

/* convert 8 bit mask to __m256i control mask */
_mm256_maskz_broadcastd_epi32((__mmask8)mask,_mm_set1_epi32(~0))
/* convert 16 bit mask to __m512i control mask */
_mm512_maskz_broadcastd_epi32((__mmask16)mask,_mm_set1_epi32(~0))

要根据掩码设置每个四元组字的最高位,以获得一个向量:

/* convert 8 bit mask to __m512i control mask */
_mm512_maskz_broadcastq_epi64((__mmask8)mask,_mm_set1_epi32(~0))

与此问题有关的特定函数是: _mm256_maskz_broadcastb_epi8((__mmask32)mask,_mm_set1_epi32(~0)),但为了参考/比较,我也包括了其他内容。

请注意,根据掩码,每个字节/字/...将全部为1或全部为0(不仅仅是最高位)。这对于执行矢量化的位操作(例如和另一个向量进行'&'运算以消除不需要的字节/字)也很有用。

另一个注意点:每个_mm_set1_epi32(~0)都可以/应该转换为常量(手动或由编译器),因此它应该编译为仅一个相当快的操作,尽管在测试中可能比实际生活中略快,因为常量可能会保留在寄存器中。然后这些被转换为VPMOVM2{b,w,d,q}指令

编辑:如果您的编译器不支持AVX512,则内联汇编版本应如下:

inline __m256i dmask2epi8(__mmask32 mask){
  __m256i ret;
  __asm("vpmovm2b   %1, %0":"=x"(ret):"k"(mask):);
  return ret;
}

其他指令类似。

如果您想要0 / -1,请使用 _mm256_movm_epi8(mask),而不是零掩码广播。除了-1之外的另一个选项是_mm256_maskz_mov_epi8(mask32, _mm256_set1_epi8(1))。 如果没有vpmovm2b,广播将很有趣,因为128位全1比创建512位更便宜(vpcmpeqd same,same被特殊处理为dep-breaking),但广播是只能在端口5上运行的洗牌。请参见将16位掩码转换为16字节掩码的AVX-512部分(其中大多数想要0 / 1,而不是普通的0 / -1)。 - Peter Cordes

3

我能想到的唯一相对高效的方法是使用8位查找表:进行4次8位查找,然后将结果加载到矢量中,例如:

static const uint64_t LUT[256] = { 0x0000000000000000ULL,
                                   ...
                                   0xffffffffffffffffULL };

uint64_t amask[4] __attribute__ ((aligned(32)));

uint32_t mask;
__m256i vmask;

amask[0] = LUT[mask & 0xff];
amask[1] = LUT[(mask >> 8) & 0xff];
amask[2] = LUT[(mask >> 16) & 0xff];
amask[3] = LUT[mask >> 24];
vmask = _mm256_load_si256((__m256i *)amask);

另一种方法是使用寄存器而不是临时数组,看看你的编译器是否能够做出更高效的操作,而不需要通过内存:

static const uint64_t LUT[256] = { 0x0000000000000000ULL,
                                   ...
                                   0xffffffffffffffffULL };

uint64_t amask0, amask1, amask2, amask3;

uint32_t mask;
__m256i vmask;

amask0 = LUT[mask & 0xff];
amask1 = LUT[(mask >> 8) & 0xff];
amask2 = LUT[(mask >> 16) & 0xff];
amask3 = LUT[mask >> 24];
vmask = _mm256_set_epi64x(amask3, amask2, amask1, amask0);

补充一点:一个有趣的挑战可能是使用例如Haswell BMI指令来执行等同于8 -> 64位LUT操作,从而摆脱LUT。看起来您可以使用PDEP进行此操作,例如:

const uint64_t pmask = 0x8080808080808080ULL; // bit unpacking mask for PDEP

uint64_t amask0, amask1, amask2, amask3;

uint32_t mask;
__m256i vmask;

amask0 = _pdep_u64(mask, pmask); mask >>= 8;
amask1 = _pdep_u64(mask, pmask); mask >>= 8;
amask2 = _pdep_u64(mask, pmask); mask >>= 8;
amask3 = _pdep_u64(mask, pmask);
vmask = _mm256_set_epi64x(amask3, amask2, amask1, amask0);

是的,如果可能的话,我想避免使用查找表。与我正在执行的基于寄存器的操作相比,它们非常昂贵。 - Satya Arjunan

3
这里提供另一种实现方法,适用于AVX2技术(因为在您的问题中有该标签),但尚未经过测试,因为我没有Haswell机器。它与Evgeny Kluev的答案类似,但可能需要更少的指令。但是,它需要两个常量__m256i掩码。如果您要在循环中多次执行此操作,则一次性设置这些常量的开销可能可以忽略不计。
  • 使用_mm_broadcastd_epi32()将32位掩码广播到一个ymm寄存器的所有8个插槽。

  • 创建一个包含8个32位整数的__m256i,其值为[0,1,2,3,4,5,6,7](从最低有效元素到最高有效元素)。

  • 使用该常数掩码将ymm寄存器中的每个32位整数向左旋转不同的量,使用_mm256_sllv_epi32()

  • 现在,如果我们将ymm寄存器视为包含8位整数并查看它们的MSB,则该寄存器现在保存了字节索引的MSB [7,15,23,31,6,14,22,30,5,13,21,29,4,12,20,28,3,11,19,27,2,10,18,26,1,9, 17,25,0,8,16,24](从最低有效元素到最高有效元素)。

  • 使用一个常量掩码[0x80,0x80,0x80,...]的按位AND来隔离每个字节的MSB。

  • 使用一系列洗牌和/或排列来将元素按您想要的顺序放回。不幸的是,与AVX2中的浮点值一样,没有针对8位整数的任何对任意置换(any-to-any permute)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接