高效的整数比较函数

65
compare函数是一个接受两个参数ab,并返回描述它们顺序的整数的函数。如果a小于b,则结果为负整数。如果a大于b,则结果为正整数。否则,ab相等,结果为零。
这个函数常用于标准库中的排序和搜索算法的参数化。
实现字符比较的compare函数非常容易;您只需对参数进行减法运算即可。
int compare_char(char a, char b)
{
    return a - b;
}

这种方法之所以有效,是因为通常假定两个字符之间的差值可以装进一个整数中。(请注意,在 sizeof(char) == sizeof(int) 的系统中,这个假设不成立。)
对于比较整数,这种技巧是行不通的,因为两个整数之间的差值通常无法容纳在一个整数中。例如,INT_MAX - (-1) = INT_MIN 表明 INT_MAX-1 小(严格来说,溢出会导致未定义的行为,但我们假设使用模算术)。
那么如何高效地实现整数的比较函数呢?以下是我的第一次尝试:
int compare_int(int a, int b)
{
    int temp;
    int result;
    __asm__ __volatile__ (
        "cmp %3, %2 \n\t"
        "mov $0, %1 \n\t"

        "mov $1, %0 \n\t"
        "cmovg %0, %1 \n\t"

        "mov $-1, %0 \n\t"
        "cmovl %0, %1 \n\t"
    : "=r"(temp), "=r"(result)
    : "r"(a), "r"(b)
    : "cc");
    return result;
}

能否用不到6条指令完成?是否有更高效但不那么直接的方法?


28
你拆解过 return (a<b)?-1:(a>b); 这段代码吗? - jxh
5
只有五个指令,太好了!现在我感觉自己很蠢 :) - fredoverflow
28
这几乎可以被视为无意义的优化。相较于调用开销和调用者中算法成本,这里所节省的成本微不足道。 - Raymond Chen
19
如果这是内联的话,那么情况会更糟糕,因为编译器无法优化内联汇编。最好用C语言表达,这样优化器可以将其优化成内联函数。例如,你的内联汇编阻止了编译器将 if (compare_int(a,b) < 0) 优化为 if (a < b) - Raymond Chen
30
雷蒙德·陈(Raymond Chen)刚刚发布了一篇有趣的文章,探讨了整数比较的基准测试以及优化器能够实现的内容。 - isanae
显示剩余16条评论
7个回答

97

这个没有分支,也不会出现溢出或下溢的问题:

return (a > b) - (a < b);

使用 gcc -O2 -S 编译,这将编译成以下六条指令:
xorl    %eax, %eax
cmpl    %esi, %edi
setl    %dl
setg    %al
movzbl  %dl, %edx
subl    %edx, %eax

以下是用于基准测试不同比较实现的一些代码:

#include <stdio.h>
#include <stdlib.h>

#define COUNT 1024
#define LOOPS 500
#define COMPARE compare2
#define USE_RAND 1

int arr[COUNT];

int compare1 (int a, int b)
{
    if (a < b) return -1;
    if (a > b) return 1;
    return 0;
}

int compare2 (int a, int b)
{
    return (a > b) - (a < b);
}

int compare3 (int a, int b)
{
    return (a < b) ? -1 : (a > b);
}

int compare4 (int a, int b)
{
    __asm__ __volatile__ (
        "sub %1, %0 \n\t"
        "jno 1f \n\t"
        "cmc \n\t"
        "rcr %0 \n\t"
        "1: "
    : "+r"(a)
    : "r"(b)
    : "cc");
    return a;
}

int main ()
{
    for (int i = 0; i < COUNT; i++) {
#if USE_RAND
        arr[i] = rand();
#else
        for (int b = 0; b < sizeof(arr[i]); b++) {
            *((unsigned char *)&arr[i] + b) = rand();
        }
#endif
    }

    int sum = 0;

    for (int l = 0; l < LOOPS; l++) {
        for (int i = 0; i < COUNT; i++) {
            for (int j = 0; j < COUNT; j++) {
                sum += COMPARE(arr[i], arr[j]);
            }
        }
    }

    printf("%d=0\n", sum);

    return 0;
}

在我的64位系统上,使用gcc -std=c99 -O2编译,针对正整数进行测试(USE_RAND=1),结果如下:

compare1: 0m1.118s
compare2: 0m0.756s
compare3: 0m1.101s
compare4: 0m0.561s

在仅使用C语言的解决方案中,我提出的那个是最快的。用户315052的解决方案尽管只编译为5条指令,但速度较慢。这种减速很可能是因为尽管有一条更少的指令,但有一个条件指令 (cmovge)。

总的来说,FredOverflow的4条指令的汇编实现在使用正整数时是最快的。然而,该代码仅对整数范围RAND_MAX进行基准测试,因此4条指令的测试具有偏见,因为它单独处理溢出,并且在测试中不会发生这些情况;速度可能是由于成功的分支预测导致的。

对于全部整数范围 (USE_RAND=0),实际上4条指令的解决方案非常慢(其他方案相同):

compare4: 0m1.897s

你能修改基准测试以生成完整的整数谱吗?如果你用 rand() | rand() << 17 替换 rand() 会发生什么? - fredoverflow
@AmbrozBizjak:你验证了编译器是否生成了我在答案中发布的相同的5条指令吗?我只是好奇。 - jxh
@jxh 在我的当前系统上,我得到了:compare2 0m0.488s,compare3 0m0.502s。汇编代码与发布的不同,但类似:http://ideone.com/uhmIYt 和 http://ideone.com/cD9Zd2。这些基准测试的问题在于结果还取决于编译器如何将代码内联到循环中。如果我们强制它不内联比较函数,可能会更好。 - Ambroz Bizjak
注意,在gcc 7.x中(但不是6.x或8.x),compare3(...) { a < b ? -1 : (a > b); }使用了分支-所以如果您的比较是不可预测的,您可能会遇到令人不愉快的意外! - BeeOnRope

56
以下方法对我来说一直非常有效:
return (a < b) ? -1 : (a > b);

使用gcc -O2 -S编译时,代码会被编译成以下五条指令:
xorl    %edx, %edx
cmpl    %esi, %edi
movl    $-1, %eax
setg    %dl
cmovge  %edx, %eax

作为Ambroz Bizjak优秀的配套答案的跟进,我并不确定他的程序是否测试了上述发布的相同汇编代码。当我更仔细地研究编译器输出时,我注意到编译器没有生成与我们任何一个回答中发布的相同指令。所以,我拿他的测试程序,手动修改汇编输出以匹配我们发布的内容,并比较了结果时间。两个版本看起来大致相同。
./opt_cmp_branchless: 0m1.070s
./opt_cmp_branch:     0m1.037s

我会尽力帮忙翻译。这段文字的意思是作者会公开每个程序的完整汇编代码,以便其他人可以尝试相同的实验,并确认或反驳他的观察结果。下面是使用 "cmovge" 指令的版本("(a < b) ? -1 : (a > b)")。
        .file   "cmp.c"
        .text
        .section        .rodata.str1.1,"aMS",@progbits,1
.LC0:
        .string "%d=0\n"
        .text
        .p2align 4,,15
.globl main
        .type   main, @function
main:
.LFB20:
        .cfi_startproc
        pushq   %rbp
        .cfi_def_cfa_offset 16
        .cfi_offset 6, -16
        pushq   %rbx
        .cfi_def_cfa_offset 24
        .cfi_offset 3, -24
        movl    $arr.2789, %ebx
        subq    $8, %rsp
        .cfi_def_cfa_offset 32
.L9:
        leaq    4(%rbx), %rbp
.L10:
        call    rand
        movb    %al, (%rbx)
        addq    $1, %rbx
        cmpq    %rbx, %rbp
        jne     .L10
        cmpq    $arr.2789+4096, %rbp
        jne     .L9
        xorl    %r8d, %r8d
        xorl    %esi, %esi
        orl     $-1, %edi
.L12:
        xorl    %ebp, %ebp
        .p2align 4,,10
        .p2align 3
.L18:
        movl    arr.2789(%rbp), %ecx
        xorl    %eax, %eax
        .p2align 4,,10
        .p2align 3
.L15:
        movl    arr.2789(%rax), %edx
        xorl    %ebx, %ebx
        cmpl    %ecx, %edx
        movl    $-1, %edx
        setg    %bl
        cmovge  %ebx, %edx
        addq    $4, %rax
        addl    %edx, %esi
        cmpq    $4096, %rax
        jne     .L15
        addq    $4, %rbp
        cmpq    $4096, %rbp
        jne     .L18
        addl    $1, %r8d
        cmpl    $500, %r8d
        jne     .L12
        movl    $.LC0, %edi
        xorl    %eax, %eax
        call    printf
        addq    $8, %rsp
        .cfi_def_cfa_offset 24
        xorl    %eax, %eax
        popq    %rbx
        .cfi_def_cfa_offset 16
        popq    %rbp
        .cfi_def_cfa_offset 8
        ret
        .cfi_endproc
.LFE20:
        .size   main, .-main
        .local  arr.2789
        .comm   arr.2789,4096,32
        .section        .note.GNU-stack,"",@progbits

下面的版本使用无分支方法 ((a > b) - (a < b)):
        .file   "cmp.c"
        .text
        .section        .rodata.str1.1,"aMS",@progbits,1
.LC0:
        .string "%d=0\n"
        .text
        .p2align 4,,15
.globl main
        .type   main, @function
main:
.LFB20:
        .cfi_startproc
        pushq   %rbp
        .cfi_def_cfa_offset 16
        .cfi_offset 6, -16
        pushq   %rbx
        .cfi_def_cfa_offset 24
        .cfi_offset 3, -24
        movl    $arr.2789, %ebx
        subq    $8, %rsp
        .cfi_def_cfa_offset 32
.L9:
        leaq    4(%rbx), %rbp
.L10:
        call    rand
        movb    %al, (%rbx)
        addq    $1, %rbx
        cmpq    %rbx, %rbp
        jne     .L10
        cmpq    $arr.2789+4096, %rbp
        jne     .L9
        xorl    %r8d, %r8d
        xorl    %esi, %esi
.L19:
        movl    %ebp, %ebx
        xorl    %edi, %edi
        .p2align 4,,10
        .p2align 3
.L24:
        movl    %ebp, %ecx
        xorl    %eax, %eax
        jmp     .L22
        .p2align 4,,10
        .p2align 3
.L20:
        movl    arr.2789(%rax), %ecx
.L22:
        xorl    %edx, %edx
        cmpl    %ebx, %ecx
        setg    %cl
        setl    %dl
        movzbl  %cl, %ecx
        subl    %ecx, %edx
        addl    %edx, %esi
        addq    $4, %rax
        cmpq    $4096, %rax
        jne     .L20
        addq    $4, %rdi
        cmpq    $4096, %rdi
        je      .L21
        movl    arr.2789(%rdi), %ebx
        jmp     .L24
.L21:
        addl    $1, %r8d
        cmpl    $500, %r8d
        jne     .L19
        movl    $.LC0, %edi
        xorl    %eax, %eax
        call    printf
        addq    $8, %rsp
        .cfi_def_cfa_offset 24
        xorl    %eax, %eax
        popq    %rbx
        .cfi_def_cfa_offset 16
        popq    %rbp
        .cfi_def_cfa_offset 8
        ret
        .cfi_endproc
.LFE20:
        .size   main, .-main
        .local  arr.2789
        .comm   arr.2789,4096,32
        .section        .note.GNU-stack,"",@progbits

3
附加的好处是它不依赖于 asm 关键字,如你所知,在一些平台上这样做行不通。+1 - John Dibling
1
这只能在i686和后续版本的x86 CPU系列上运行,例如那些具有CMOV指令的CPU。实际上,它适用于Pentium Pro之后生产的任何CPU(除了Pentium MMX),所以您应该没有问题,除非您的用户恰好拥有超过10年历史的机器。 - Daniel Kamil Kozar
3
如果你只使用 C 版本,那么假设编译器还算不错的话,你应该会得到几乎最优化的代码。 - Paul R
当然,你是对的。我只是指出了来自这段特定代码的事实。 - Daniel Kamil Kozar
@FredOverflow:cmovge 真的会导致流水线停顿吗?你展示的汇编输出中似乎没有条件跳转。你使用的编译器版本是什么? - jxh

16

好的,我成功将其简化为四条指令 :) 基本思路如下:

一半情况下,差异小到可以放进整数中。在这种情况下,只需返回差异即可。否则,把数字向右移动一位。关键问题是要将哪个位移入MSB中。

为了简单起见,让我们使用8位来看两个极端的例子:

 10000000 INT_MIN
 01111111 INT_MAX
---------
000000001 difference
 00000000 shifted

 01111111 INT_MAX
 10000000 INT_MIN
---------
111111111 difference
 11111111 shifted

将进位位移入会导致第一种情况得到0(尽管INT_MIN不等于INT_MAX),第二种情况得到一些负数(尽管INT_MAX不小于INT_MIN)。但是,如果我们在进行移位之前翻转进位位,我们将得到有意义的数字:
 10000000 INT_MIN
 01111111 INT_MAX
---------
000000001 difference
100000001 carry flipped
 10000000 shifted

 01111111 INT_MAX
 10000000 INT_MIN
---------
111111111 difference
011111111 carry flipped
 01111111 shifted

我确定有一个深刻的数学原因解释为什么翻转进位位是有意义的,但我还没看出来。

int compare_int(int a, int b)
{
    __asm__ __volatile__ (
        "sub %1, %0 \n\t"
        "jno 1f \n\t"
        "cmc \n\t"
        "rcr %0 \n\t"
        "1: "
    : "+r"(a)
    : "r"(b)
    : "cc");
    return a;
}

我已经使用一百万个随机输入和每种组合的INT_MIN,-INT_MAX,INT_MIN/2,-1,0,1,INT_MAX/2,INT_MAX来测试代码。所有测试都通过了。你能证明我错了吗?


即使只有4条指令,你可能仍然会发现5或6条无分支序列更快。 - Paul R
1
@mfontanini 我预计我的解决方案在随机输入的情况下会是最慢的。另一方面,如果溢出很少发生,我会期望它表现得非常好(分支预测)。无论如何,我只是觉得很有挑战性,想找到一个更短的解决方案 :) - fredoverflow
1
如果你为一个一位全减法器写出真值表,你会发现反转进位确实等于有符号溢出的正确符号。在此过程中,请记住,在计算正确符号和有符号溢出时,“被减数位=1”和“减数位=1”应该在算术上视为“-1”,而且只有当和为“0”或“-1”时才不会溢出。因此,你的猜测是正确的。 - Alexey Frunze

10

就我所知,我编写了一个SSE2实现。 vec_compare1使用与compare2相同的方法,但只需要三个SSE2算术指令:

#include <stdio.h>
#include <stdlib.h>
#include <emmintrin.h>

#define COUNT 1024
#define LOOPS 500
#define COMPARE vec_compare1
#define USE_RAND 1

int arr[COUNT] __attribute__ ((aligned(16)));

typedef __m128i vSInt32;

vSInt32 vec_compare1 (vSInt32 va, vSInt32 vb)
{
    vSInt32 vcmp1 = _mm_cmpgt_epi32(va, vb);
    vSInt32 vcmp2 = _mm_cmpgt_epi32(vb, va);
    return _mm_sub_epi32(vcmp2, vcmp1);
}

int main ()
{
    for (int i = 0; i < COUNT; i++) {
#if USE_RAND
        arr[i] = rand();
#else
        for (int b = 0; b < sizeof(arr[i]); b++) {
            *((unsigned char *)&arr[i] + b) = rand();
        }
#endif
    }

    vSInt32 vsum = _mm_set1_epi32(0);

    for (int l = 0; l < LOOPS; l++) {
        for (int i = 0; i < COUNT; i++) {
            for (int j = 0; j < COUNT; j+=4) {
                vSInt32 v1 = _mm_loadu_si128(&arr[i]);
                vSInt32 v2 = _mm_load_si128(&arr[j]);
                vSInt32 v = COMPARE(v1, v2);
                vsum = _mm_add_epi32(vsum, v);
            }
        }
    }

    printf("vsum = %vd\n", vsum);

    return 0;
}

这需要0.137秒的时间。

同样的CPU和编译器比较compare2所需时间为0.674秒。

因此,SSE2实现大约快了4倍,这是可以预期的(因为它是4宽SIMD)。


3

这段代码没有分支,只使用了5个指令。在最近的英特尔处理器上,可能比其他无分支的替代方案表现更好,因为cmov*指令非常昂贵。缺点是返回值不对称(INT_MIN+1、0、1)。

int compare_int (int a, int b)
{
    int res;

    __asm__ __volatile__ (
        "xor %0, %0 \n\t"
        "cmpl %2, %1 \n\t"
        "setl %b0 \n\t"
        "rorl $1, %0 \n\t"
        "setnz %b0 \n\t"
    : "=q"(res)
    : "r"(a)
    , "r"(b)
    : "cc"
    );

    return res;
}

这个变量不需要初始化,因此它只使用4条指令:

int compare_int (int a, int b)
{
    __asm__ __volatile__ (
        "subl %1, %0 \n\t"
        "setl %b0 \n\t"
        "rorl $1, %0 \n\t"
        "setnz %b0 \n\t"
    : "+q"(a)
    : "r"(b)
    : "cc"
    );

    return a;
}

你提出的解决方案似乎有缺陷,请查看我的编辑。 - fredoverflow
@FredOverflow:没错,'res'的初始化缺失了。 - Evgeny Kluev

0

也许你可以使用以下思路(伪代码形式;没有写汇编代码,因为我对语法不太熟悉):

  1. 相减得到结果 (result = a - b)
  2. 如果没有溢出,直接结束(jo 指令和分支预测在这里应该能很好地工作)
  3. 如果有溢出,使用任何可靠的方法(return (a < b) ? -1 : (a > b)

编辑:为了更简单:如果有溢出,反转结果的符号,而不是执行步骤 3。


你的备选步骤3对于 compare_int(0, -2147483648) 不起作用。 - fredoverflow
同意,这是一个糟糕的优化。 - anatolyg

-2
你可以考虑将整数提升为64位值。

1
可以这样做,但是由于你必须返回一个整数,你仍然需要处理任何溢出。 - Paul R
@PaulR:只需屏蔽掉额外的位即可。 - Puppy
3
不行 - 需要更多 - 考虑a-b无法用32位表示的情况 - 如果只屏蔽高位比特,则结果符号将不正确。在转换回32位之前,您需要例如饱和到64位。这需要超过5个指令。 - Paul R
1
@DeadMG:但这些将会错误地将-1或+1舍入为零。我的想法是沿着x=x | x>>32的线路进行逻辑移位。这样可以将符号位移动到正确的位置,并确保只有0LL最终变成零。 它唯一的一个缺陷是将0x80000000LL视为负值。 - MSalters
@PaulR 当你说“饱和”时,我立刻感到鸡皮疙瘩,因为MMX已经饱和了整数运算!然而事实证明,只有8位和16位整数才是如此:( - fredoverflow
显示剩余8条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接