高效并行的字节级有符号整数“小于或等于”谓词计算

3
在一些上下文中,例如生物信息学,对字节大小的整数进行计算就足够了。为了获得最佳性能,许多处理器架构提供SIMD指令集(例如MMX,SSE,AVX),将寄存器分成字节、半字和字大小的组件,然后单独执行相应组件的算术、逻辑和移位操作。
然而,一些架构不提供这样的SIMD指令,需要进行模拟,这通常需要大量的位运算。目前,我正在研究SIMD比较,特别是有符号的字节大小整数的并行比较。我有一个解决方案,认为使用可移植的C代码非常高效(请参见下面的函数vsetles4)。它基于Peter Montgomery在2000年的网络帖子中所做的观察,即(A+B)/2 = (A AND B) + (A XOR B)/2,在中间计算中没有溢出。
这个特定的仿真代码(函数vsetles4)能否进一步加速?一般来说,任何基本操作次数更少的解决方案都符合要求。我寻找的是可移植的ISO-C99解决方案,不使用机器特定的内部函数(intrinsics)。大多数架构都支持ANDN(a & ~b),因此可以假定它可以作为单条高效指令使用。
#include <stdint.h>

/*
   vsetles4 treats its inputs as arrays of bytes each of which comprises
   a signed integers in [-128,127]. Compute in byte-wise fashion, between
   corresponding bytes of 'a' and 'b', the boolean predicate "less than 
   or equal" as a value in [0,1] into the corresponding byte of the result.
*/

/* reference implementation */
/* reference implementation: unpack, compare as signed bytes, repack.
   For each 8-bit lane of 'a' and 'b', evaluate the signed comparison
   a <= b and deposit the 0/1 outcome in the same lane of the result. */
uint32_t vsetles4_ref (uint32_t a, uint32_t b)
{
    uint32_t result = 0;

    for (int shift = 0; shift < 32; shift += 8) {
        int8_t lane_a = (int8_t)(uint8_t)(a >> shift);
        int8_t lane_b = (int8_t)(uint8_t)(b >> shift);
        result |= (uint32_t)(lane_a <= lane_b) << shift;
    }
    return result;
}

/* Optimized implementation:

   a <= b; a - b <= 0;  a + ~b + 1 <= 0; a + ~b < 0; (a + ~b)/2 < 0.
   Compute avg(a,~b) without overflow, rounding towards -INF; then
   lteq(a,b) = sign bit of result. In other words: compute 'lteq' as 
   (a & ~b) + arithmetic_right_shift (a ^ ~b, 1) giving the desired 
   predicate in the MSB of each byte.
*/
/* SWAR implementation: compute avg(a,~b) per byte lane without overflow,
   rounding toward -INF; the sign bit of each lane is then the predicate
   a <= b for that lane, which is shifted down into bit 0 of the lane. */
uint32_t vsetles4 (uint32_t a, uint32_t b)
{
    const uint32_t no_lsb = 0xfefefefe; /* keep shift from crossing lanes */
    const uint32_t signs  = 0x80808080; /* MSB position of each byte lane */
    uint32_t x, half, avg;

    x    = a ^ ~b;                /* a XOR ~b */
    half = (x & no_lsb) >> 1;     /* logical part of per-lane asr(x,1) */
    avg  = (a & ~b) + half;       /* (a & ~b) + asr(a ^ ~b, 1), partial */
    avg ^= x;                     /* complete arithmetic shift and add */
    return (avg & signs) >> 7;    /* byte-wise predicate in [0,1] */
}

1
你有实际计时过吗?相对于遍历一个 int8_t 数组并应用 <=,这种方法是否更快?(不是相对于 vsetles4_ref 计时 - 而是相对于根本不尝试将这些东西打包成 uint32_t。) - user2357112
1
这是一个超级有趣的问题。 - Matt Timmermans
2
@AldwinCheung 我最初是为 NVIDIA GPU 研究这个的。Kepler 架构在硬件上(大部分)具备此操作,而后续 GPU 架构则需要模拟。在 NVIDIA 期间,我创建了一组较大的原语(此处,采用 BSD 许可证)。虽然我已经退休,但 SO 上的一个最近问题 促使我重新访问这个模拟,这与架构无关(也可能对低端 ARM 有用)。 - njuffa
1
@QPaysTaxes 就我所知,优化问题在这里不属于离题讨论(这是我非常熟悉的一个网站),因此没有必要迁移到 CodeReview.SE(这是我完全不熟悉的一个网站)。此外,该问题被标记为“不清楚”,而不是“离题”,不清楚原因。我会保持现状,这是我的爱好之一,我没有迫切需要找到更好的解决方案,我只是对潜在的更优代码方式感到好奇。 - njuffa
1
我可能表达不清,因为我的问题不是“如何改进这段代码”,而是“如何用更少的基本操作实现这个功能”,这对我来说似乎非常具体。“更少的基本操作”足以加速我感兴趣的简单顺序处理器。欢迎提出重新表述问题的建议,以使其更加清晰明了。 - njuffa
显示剩余9条评论
1个回答

0
为了[可能地]节省你的一些工作并回答用户2357112的问题,我创建了一个[粗略的]基准测试。我将每次添加一个字节作为基本参考:
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>

long opt_R;     /* -R option: PRNG seed passed to srand (defaults to 1 in main) */
long opt_N;     /* -N option: number of 32-bit elements per array (defaults to 1e8) */

void *aptr;     /* benchmark input array 'a' (opt_N uint32_t words) */
void *bptr;     /* benchmark input array 'b' (opt_N uint32_t words) */
void *cptr;     /* benchmark output array (opt_N uint32_t words) */

/*
   vsetles4 treats its inputs as arrays of bytes each of which comprises
   a signed integers in [-128,127]. Compute in byte-wise fashion, between
   corresponding bytes of 'a' and 'b', the boolean predicate "less than
   or equal" as a value in [0,1] into the corresponding byte of the result.
*/

/* base implementation */
/* byte-at-a-time baseline: for 'count' groups of four bytes, store
   (a[i] <= b[i]) as 0/1 into the corresponding output byte.

   Bug fix: the original used plain 'char', whose signedness is
   implementation-defined; on ABIs where char is unsigned (ARM, PowerPC)
   this computed an unsigned comparison instead of the required signed
   one. int8_t makes the signed interpretation explicit and portable. */
void
vsetles4_base(const void *va, const void *vb, long count, void *vc)
{
    const int8_t *ap = va;
    const int8_t *bp = vb;
    int8_t *cp = vc;
    long idx;

    count *= 4;     /* 'count' is in 4-byte groups; process bytes */

    for (idx = 0;  idx < count;  ++idx)
        cp[idx] = (ap[idx] <= bp[idx]);
}

/* reference implementation */
/* reference implementation: per-lane signed-byte compare via unpack/repack */
static inline uint32_t
_vsetles4_ref(uint32_t a, uint32_t b)
{
    uint32_t result = 0;

    /* reinterpret each byte lane as a signed value, compare, and place
       the 0/1 outcome back into the same lane */
    for (int shift = 0; shift < 32; shift += 8) {
        int8_t lane_a = (int8_t)(uint8_t)(a >> shift);
        int8_t lane_b = (int8_t)(uint8_t)(b >> shift);
        result |= (uint32_t)(lane_a <= lane_b) << shift;
    }
    return result;
}

/* external (non-inlined) entry point for the reference implementation */
uint32_t
vsetles4_ref(uint32_t a, uint32_t b)
{
    return _vsetles4_ref(a, b);
}

/* Optimized implementation:
   a <= b; a - b <= 0;  a + ~b + 1 <= 0; a + ~b < 0; (a + ~b)/2 < 0.
   Compute avg(a,~b) without overflow, rounding towards -INF; then
   lteq(a,b) = sign bit of result. In other words: compute 'lteq' as
   (a & ~b) + arithmetic_right_shift (a ^ ~b, 1) giving the desired
   predicate in the MSB of each byte.
*/
/* SWAR implementation: avg(a,~b) per byte lane, rounded toward -INF,
   computed without overflow as (a & ~b) + asr(a ^ ~b, 1); the sign bit
   of each lane is the predicate a <= b, moved down to bit 0 of the lane. */
static inline uint32_t
_vsetles4(uint32_t a, uint32_t b)
{
    const uint32_t no_lsb = 0xfefefefe; /* keep shift inside byte lanes */
    const uint32_t signs  = 0x80808080; /* MSB of every byte lane */
    uint32_t x, half, avg;

    x    = a ^ ~b;               /* a XOR ~b */
    half = (x & no_lsb) >> 1;    /* logical part of per-lane asr(x,1) */
    avg  = (a & ~b) + half;      /* partial sum */
    avg ^= x;                    /* complete arithmetic shift and add */
    return (avg & signs) >> 7;   /* byte-wise predicate in [0,1] */
}

/* external (non-inlined) entry point for the SWAR implementation */
uint32_t
vsetles4(uint32_t a, uint32_t b)
{
    return _vsetles4(a, b);
}

/* Optimized implementation:
   a <= b; a - b <= 0;  a + ~b + 1 <= 0; a + ~b < 0; (a + ~b)/2 < 0.
   Compute avg(a,~b) without overflow, rounding towards -INF; then
   lteq(a,b) = sign bit of result. In other words: compute 'lteq' as
   (a & ~b) + arithmetic_right_shift (a ^ ~b, 1) giving the desired
   predicate in the MSB of each byte.
*/
/* 64-bit (8-lane) SWAR variant of _vsetles4: computes avg(a,~b) per byte
   lane without overflow, rounding toward -INF, then extracts each lane's
   sign bit as the predicate a <= b in [0,1].

   Fix: the original mask literals used the 'll' suffix; the value
   0x8080808080808080 does not fit in 'long long', so per C99 6.4.4.1 the
   hex constant silently took type 'unsigned long long' anyway, drawing
   compiler warnings. Use an explicit 'ull' suffix on both masks. */
static inline uint64_t
_vsetles8(uint64_t a, uint64_t b)
{
    uint64_t m, s, t, nb;
    nb = ~b;            // ~b
    s = a & nb;         // a & ~b
    t = a ^ nb;         // a ^ ~b
    m = t & 0xfefefefefefefefeull; // don't cross byte boundaries during shift
    m = m >> 1;         // logical portion of arithmetic right shift
    s = s + m;          // start (a & ~b) + arithmetic_right_shift (a ^ ~b, 1)
    s = s ^ t;          // complete arithmetic right shift and addition
    s = s & 0x8080808080808080ull; // MSB of each byte now contains predicate
    t = s >> 7;         // result is byte-wise predicate in [0,1]
    return t;
}

/* external (non-inlined) entry point for the 8-lane version.

   Bug fix: the original declared the return type as uint32_t, silently
   truncating the upper four predicate lanes computed by _vsetles8().
   Return the full 64-bit result. */
uint64_t
vsetles8(uint64_t a, uint64_t b)
{
    return _vsetles8(a, b);
}

/* apply the reference scalar predicate across 'count' 32-bit words */
void
aryref(const void *va, const void *vb, long count, void *vc)
{
    const uint32_t *src_a = va;
    const uint32_t *src_b = vb;
    uint32_t *dst = vc;

    for (long i = 0; i < count; ++i)
        dst[i] = _vsetles4_ref(src_a[i], src_b[i]);
}

/* apply the 32-bit SWAR predicate across 'count' 32-bit words */
void
arybest4(const void *va, const void *vb, long count, void *vc)
{
    const uint32_t *src_a = va;
    const uint32_t *src_b = vb;
    uint32_t *dst = vc;

    for (long i = 0; i < count; ++i)
        dst[i] = _vsetles4(src_a[i], src_b[i]);
}

/* apply the 64-bit SWAR predicate, two 4-byte groups per iteration */
void
arybest8(const void *va, const void *vb, long count, void *vc)
{
    const uint64_t *src_a = va;
    const uint64_t *src_b = vb;
    uint64_t *dst = vc;

    /* NOTE(review): halving drops the final 4-byte group when 'count'
       is odd — confirm callers always pass an even count */
    count >>= 1;

    for (long i = 0; i < count; ++i)
        dst[i] = _vsetles8(src_a[i], src_b[i]);
}

/* tvgetf - return a timestamp in seconds as a double.
 *
 * Fix: use CLOCK_MONOTONIC instead of CLOCK_REALTIME. The realtime
 * clock can be stepped by NTP or an administrator while a benchmark
 * runs, corrupting the measured interval; the monotonic clock cannot
 * go backwards. Callers only use differences of this value, so the
 * change of epoch is transparent to them.
 */
double
tvgetf(void)
{
    struct timespec ts;
    double sec;

    clock_gettime(CLOCK_MONOTONIC, &ts);
    sec = ts.tv_nsec;
    sec /= 1e9;                 /* nanoseconds -> fractional seconds */
    sec += ts.tv_sec;

    return sec;
}

/* time one invocation of 'fnc' over the global arrays and print the
   elapsed wall-clock seconds tagged with 'sym' */
void
timeit(void (*fnc)(const void *, const void *, long, void *), const char *sym)
{
    double start = tvgetf();
    fnc(aptr, bptr, opt_N, cptr);
    double elapsed = tvgetf() - start;

    printf("timeit: %.9f %s\n", elapsed, sym);
}

// fill -- fill array with random numbers
//
// Fix: rand() returns values in [0, RAND_MAX], where RAND_MAX may be as
// small as 32767 and is commonly 0x7fffffff, so a single call never sets
// the upper bit(s) of a 32-bit word: the high byte lane would never hold
// a negative value, biasing the benchmark data. Combine three calls at
// different shifts so every bit (including each byte's sign bit) varies.
void
fill(void *vptr)
{
    uint32_t *iptr = vptr;

    for (long idx = 0;  idx < opt_N;  ++idx)
        iptr[idx] = ((uint32_t)rand() << 24) ^
                    ((uint32_t)rand() << 12) ^
                    (uint32_t)rand();
}

// main -- main program
//
// Options: -R<seed>  PRNG seed (default 1)
//          -N<count> number of 32-bit elements per array (default 1e8)
//
// Fix: the original never checked the calloc() results; with the default
// N of 1e8 (400 MB per array) allocation failure is plausible, and the
// fill/benchmark loops would then dereference NULL.
int
main(int argc, char **argv)
{
    char *cp;

    --argc;
    ++argv;

    /* parse leading -R / -N options; stop at first non-option argument */
    for (;  argc > 0;  --argc, ++argv) {
        cp = *argv;
        if (*cp != '-')
            break;

        switch (cp[1]) {
        case 'R':
            opt_R = strtol(cp + 2, &cp, 10);
            break;

        case 'N':
            opt_N = strtol(cp + 2, &cp, 10);
            break;

        default:
            break;
        }
    }

    if (opt_R == 0)
        opt_R = 1;
    srand((unsigned) opt_R);
    printf("R=%ld\n", opt_R);

    if (opt_N == 0)
        opt_N = 100000000;
    printf("N=%ld\n", opt_N);

    aptr = calloc(opt_N, sizeof(uint32_t));
    bptr = calloc(opt_N, sizeof(uint32_t));
    cptr = calloc(opt_N, sizeof(uint32_t));

    if (aptr == NULL || bptr == NULL || cptr == NULL) {
        fprintf(stderr, "main: calloc failure\n");
        return 1;
    }

    fill(aptr);
    fill(bptr);

    timeit(vsetles4_base, "base");
    timeit(aryref, "aryref");
    timeit(arybest4, "arybest4");
    timeit(arybest8, "arybest8");
    /* repeat the baseline to expose cache/warm-up effects on the first run */
    timeit(vsetles4_base, "base");

    return 0;
}

这是一次运行的输出结果:
R=1
N=100000000
timeit: 0.550527096 base
timeit: 0.483014107 aryref
timeit: 0.236460924 arybest4
timeit: 0.147254944 arybest8
timeit: 0.440311432 base

请注意,您的参考实现并不比逐字节处理快多少[在我看来,这点提升几乎不值得增加这样的复杂性]。
您优化的算法确实提供了SIMD之外的最佳性能,并且我将其扩展为使用uint64_t,这自然又将速度提高了一倍。
对您来说,测试真正的SIMD版本可能也会很有趣——以证明它们确实是最快的。

感谢您付出的所有努力。目前我没有访问支持硬件SIMD(特定GPU架构)的硬件平台,但基于我对该架构的了解,那个硬件应该比我这里的vsetles4仿真代码快大约两倍(SIMD指令的吞吐量是常规ALU指令的1/4,执行顺序)。我的原始代码(BSD许可证)可以在此处找到。自从发布软件以来,接口已经固定,完全如图所示。 - njuffa

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接