使用AVX收集半精度浮点数值

3
使用AVX/AVX2指令,我可以使用以下函数收集8个值的集合,这些值可以是1、2或4字节整数,或者是4字节浮点数:
_mm256_i32gather_epi32()
_mm256_i32gather_ps()
但是目前,我有一个情况,需要加载在nvidia GPU上生成并存储为FP16值的数据。如何进行这些值的矢量化加载?
到目前为止,我找到了_mm256_cvtph_ps()内置函数。
然而,该内置函数的输入是__m128i值,而不是__m256i值。
查看Intel Intrinsics Guide,我没有找到将8个值存储到_mm128i寄存器中的收集操作。
如何将FP16值收集到__m256寄存器的8个通道中? 是否可以将它们作为2字节短整数向量加载到__m256i中,然后以某种方式将其减少到__m128i值,以传递到转换内在函数中? 如果是这样,我还没有找到执行此操作的内在函数。

更新

我尝试了@peter-cordes建议的强制转换,但结果是虚假的。 另外,我不明白那怎么能行?

我的2字节int值存储在__m256i中:

0000XXXX 0000XXXX 0000XXXX 0000XXXX 0000XXXX 0000XXXX 0000XXXX 0000XXXX

所以我如何将其简单地转换为需要紧密打包的__m128i,如下所示:

XXXX XXXX XXXX XXXX XXXX XXXX XXXX XXXX

强制转换会做到吗?

我的当前代码:

__fp16* fielddensity = ...
__m256i indices = ...
__m256i msk = _mm256_set1_epi32(0xffff);
__m256i d = _mm256_and_si256(_mm256_i32gather_epi32(fielddensity,indices,2), msk);
__m256 v = _mm256_cvtph_ps(_mm256_castsi256_si128(d));

但结果似乎不是8个正确的值。我认为每隔两个就有一个对我来说是虚假的?

3
x86 CPU中没有对比32位小的元素进行gather(或scatter)的硬件支持。 如果您确实需要对非连续值进行gather,那么您可能需要将8个32位元素聚集在一起并将它们向下移动到一个__m256i底部的8个16位元素中,并将其用作__m128i(使用强制类型转换)。注意,收集数组的顶部元素不能越过未映射的页面。是的,x86仅支持将半精度浮点数转换为/从单精度浮点数(直到某个未来的AVX512)。 - Peter Cordes
1
对于16位整数的收集部分:为16位整数收集AVX2&512内在函数? - Peter Cordes
@PeterCordes 谢谢,我对你的第一条陈述感到困惑。只要我在_mm256_i32gather_epi32()中使用比例值“1”并在之后屏蔽所有高位,我就可以收集8个字节。我测试过了。我相当确信,在比例为2的情况下,我也可以对16位整数采取同样的方法。关于转换:我可以在__m256i值上执行(__m128i)吗?我会尝试的。 - Bram
2
为了可移植性,您应该使用 _mm256_castsi256_si128__m256i 转换为 __m128i(尽管 C 风格的转换在大多数编译器上都可以工作)。 - chtz
3
根据我理解这条指令的意思是,你正在收集8个不对齐的双字。当然,你可以忽略或掩盖除低字节以外的所有内容,或者像Peter建议的那样重新排列它们。 - Nate Eldredge
显示剩余2条评论
1个回答

3

16位值确实没有 gather 指令,因此需要收集32位值并忽略其中的一半(并确保不会意外从无效内存中读取)。另外,_mm256_cvtph_ps() 需要所有输入值都在较低的128位车道中,不幸的是,目前还没有跨车道的16位洗牌指令(直到 AVX512)。

然而,假设你只有有限的输入值,你可以进行一些位操作(避免使用 _mm256_cvtph_ps())。如果将半精度值加载到32位寄存器的上半部分,则可以执行以下操作:

SEEEEEMM MMMMMMMM XXXXXXXX XXXXXXXX  // input Sign, Exponent, Mantissa, X=garbage

将其算术向右移3位(这样可以保持符号位在需要的位置):

SSSSEEEE EMMMMMMM MMMXXXXX XXXXXXXX 

使用 0b1000'11111'11111111111'0000000000000 来屏蔽掉多余的符号位和底部的垃圾信息。

S000EEEE EMMMMMMM MMM00000 00000000

这将是一个有效的单精度浮点数,但指数会偏移 112 = 127-15(偏差之间的差异),即您需要将这些值乘以 2 ** 112(这可以与任何后续操作结合使用,您打算稍后执行)。请注意,这也会将次标准浮点16值转换为相应的次标准浮点32值(这些值也偏移了 2 ** 112 的因子)。
未经过测试的内部版本:
__m256 gather_fp16(__fp16 const* fielddensity, __m256i indices){
  // subtract 2 bytes from base address to load data into high parts:
  int32_t const* base = (int32_t const*) ( fielddensity - 1);

  // Gather 32bit values.
  // Be aware that this reads two bytes before each desired value,
  // i.e., make sure that reading fielddensitiy[-1] is ok!
  __m256i d = _mm256_i32gather_epi32(base, indices, 2);

  // shift exponent bits to the right place and mask away excessive bits:
  d = _mm256_and_si256(_mm256_srai_epi32(d, 3), _mm256_set1_epi32(0x8fffe000));

  // scale values to compensate bias difference (could be combined with subsequent operations ...)
  __m256 two112 = _mm256_castsi256_ps(_mm256_set1_epi32(0x77800000)); // 2**112
  __m256 f = _mm256_mul_ps(_mm256_castsi256_ps(d), two112);

  return f;
}

需要有限制,次正规是否特殊呢?我认为可能不是。但如果您尝试使用整数加法重新调整指数字段而不是FP乘法,则会出现这种情况。 - Peter Cordes
1
子规范应该可以工作,因为位移将把它们转换为相应的float32子规范(与float16子规范相差2 ** 122倍)。但我实际上没有测试过这一点。如果没有子规范输入,则最终乘法确实也可以通过整数加法完成。浮点乘法的另一个优点是它可以与某些后续浮点操作组合(可能是FMA)。 - chtz
感谢您发现了这个122个错别字(我在源代码注释中也犯了同样的错误——但是常量应该是正确的(也许写成(127+127-15)<<23会更好)。 - chtz
1
也许值得在代码块中添加一条注释,说明在每个元素之前加载2个字节。并且在文本中更明确地说明后果:如果未经过映射的页面先于页面对齐的数组,则可能会导致其破裂,如果您收集元素0,则可能会出现这种情况。对于那些还没有真正理解这是什么或者没有考虑到更广泛的元素后果的新手来说,这可能很容易被忽视。顺便说一句,这是个好主意,比我想用vpblendw 2向量+ vpshufb + vextracti128来喂2x vcvtph2ps,或者类似的变化要好得多。 - Peter Cordes

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接