为什么在这段 Rust 代码中没有分支预测失败惩罚?

7

我编写了这个非常简单的 Rust 函数:

fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }

    total
}

我已经写了一个基本的基准测试,用有序数组和乱序数组调用该方法:

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;

    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE/2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE/2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

我很惊讶这两个基准测试的运行时间几乎一样,而在Java中类似的基准测试在两种情况下表现出截然不同,可能是由于洗牌案例中分支预测失败所致。

我看到有提到条件移动指令,但如果我在可执行文件上运行 otool -tv(我正在Mac上运行),我在iterate方法输出中没有看到任何指令。

是否有人能够说明为什么Rust中有序和无序情况之间没有明显的性能差异呢?


4
我怀疑这与 Rust/LLVM 如何将这些循环优化为 SIMD 指令有关(我相信 Java 无法做到这一点)。 - Frxstrem
2
@Frxstrem,是的,在我的电脑上它使用AVX ISA,即使在Rust Playground中,它也通过使用“条件移动小于”指令cmovll来展开逻辑。 - sshashank124
1
@sshashank124:是的,启用完整优化(-O3)后,像LLVM和GCC这样的现代预编译编译器后端通常会将分支转换为CMOV或其他无分支序列。这也是自动向量化的先决条件。 - Peter Cordes
1个回答

11
摘要: LLVM能够通过使用cmov指令或SIMD指令的巧妙组合来移除/隐藏分支。
我使用Godbolt(加上-C opt-level=3)来查看完整汇编代码。下面我将解释汇编代码的重要部分。
它的开头是这样的:
        mov     r9, qword ptr [rdi + 8]         ; r9 = nums.len()
        test    r9, r9                          ; if len == 0
        je      .LBB0_1                         ;     goto LBB0_1
        mov     rdx, qword ptr [rdi]            ; rdx = base pointer (first element)
        cmp     r9, 7                           ; if len > 7
        ja      .LBB0_5                         ;     goto LBB0_5
        xor     eax, eax                        ; eax = 0
        xor     esi, esi                        ; esi = 0
        jmp     .LBB0_4                         ; goto LBB0_4

.LBB0_1:
        xor     eax, eax                        ; return 0
        ret

在这里,该函数区分了3种不同的“状态”:

  • 切片为空→立即返回0
  • 切片长度≤7→使用标准顺序算法(LBB0_4
  • 切片长度>7→使用SIMD算法(LBB0_5

那么让我们来看一下两种不同类型的算法!


标准顺序算法

记住,rsi (esi) 和 rax (eax) 被设置为0,而 rdx 是数据的基指针。

.LBB0_4:
        mov     ecx, dword ptr [rdx + 4*rsi]    ; ecx = nums[rsi]
        add     rsi, 1                          ; rsi += 1
        mov     edi, ecx                        ; edi = ecx
        neg     edi                             ; edi = -edi
        cmovl   edi, ecx                        ; if ecx >= 0 { edi = ecx }
        add     eax, edi                        ; eax += edi
        cmp     r9, rsi                         ; if rsi != len
        jne     .LBB0_4                         ;     goto LBB0_4
        ret                                     ; return eax

这是一个简单的循环,遍历所有num元素。在循环体中有一个小技巧:从原始元素ecx中,将取反值存储在edi中。通过使用cmovl,如果原始值为正,则edi被覆盖为原始值。这意味着edi总是为正数(即包含原始元素的绝对值)。然后将其添加到eax中(最终返回)。因此,您的if分支隐藏在cmov指令中。如此基准测试所示,执行cmov指令所需的时间与条件的概率无关。它是一条非常惊人的指令!


SIMD算法

SIMD版本由许多指令组成,这里不会完全粘贴。主循环每次可以处理16个整数!

        movdqu  xmm5, xmmword ptr [rdx + 4*rdi]
        movdqu  xmm3, xmmword ptr [rdx + 4*rdi + 16]
        movdqu  xmm0, xmmword ptr [rdx + 4*rdi + 32]
        movdqu  xmm1, xmmword ptr [rdx + 4*rdi + 48]

它们从内存加载到寄存器xmm0、xmm1、xmm3和xmm5中。每个寄存器包含四个32位值,但为了更容易理解,请想象每个寄存器仅包含一个值。所有后续指令都会单独操作这些SIMD寄存器的每个值,因此这种心理模型是可行的!我的以下解释也会让人觉得xmm寄存器仅包含单个值。
主要技巧现在在以下指令中(处理xmm5):
        movdqa  xmm6, xmm5      ; xmm6 = xmm5 (make a copy)
        psrad   xmm6, 31        ; logical right shift 31 bits (see below)
        paddd   xmm5, xmm6      ; xmm5 += xmm6
        pxor    xmm5, xmm6      ; xmm5 ^= xmm6

逻辑右移会使用符号位填充“空的高位”(即左侧“移入”的位)。通过将其向右移动31位,我们最终得到每个位置上只有符号位!因此,任何正数都将变成32个零,而任何负数都将变成32个一。因此,现在的xmm6要么是000...000(如果xmm5为正),要么是111...111(如果xmm5为负)。
接下来,将这个人造的xmm6加到xmm5上。如果xmm5是正的,则xmm6为0,因此添加它不会改变xmm5。但是,如果xmm5为负,则我们添加111...111,相当于减去1。最后,我们对xmm5和xmm6进行异或运算。同样,如果xmm5一开始是正的,则我们与000...000进行异或,这没有影响。如果xmm5一开始是负的,则我们与111...111进行异或,这意味着翻转所有位。因此,对于两种情况:
  • 如果这个元素是正数,我们什么也不做(addxor没有任何影响)
  • 如果这个元素是负数,我们减去1并翻转所有位。 这是二进制补码的取反!

因此,通过这4个指令,我们计算出了xmm5的绝对值!由于这种位操作技巧,所以没有分支。请记住,xmm5实际上包含4个整数,因此速度非常快!

现在,将这个绝对值添加到累加器中,并对其他三个包含切片值的xmm寄存器执行相同的操作。(我们不会详细讨论剩余的代码。)


SIMD with AVX2

如果我们允许LLVM发出AVX2指令(通过-C target-feature=+avx2),它甚至可以使用pabsd指令来代替四个“hacky”指令:

vpabsd  ymm2, ymmword ptr [rdx + 4*rdi]

这个指令直接从内存中加载数值,计算绝对值并在一条指令中将其存储在ymm2寄存器中!请记住,ymm寄存器的大小是xmm寄存器的两倍(容纳八个32位的值)!


1
你可能想告诉LLVM不要展开循环,这样你就可以看到它的操作而不会陷入展开中。对于clang,选项是-fno-unroll-loops,但该选项名称可能只是为了与GCC兼容,而不是LLVM自己的内部名称。此外,如果您让它使用SSSE3或AVX2,它将有望使用pabsd来执行SIMD绝对值运算,而无需使用二进制补码恒等式-x = ~(x - 1) - Peter Cordes
@PeterCordes 谢谢!我添加了一些关于 pabsd 的信息。使用 AVX2 汇编确实更加美观。 - Lukas Kalbertodt
太糟糕了,即使展开循环,LLVM仍然使用索引寻址模式,因此指令在Intel CPU上的成本为2个融合域uop。:/ 微融合和寻址模式。虽然数据在L1d缓存中很热,但它可能不会在前端成为瓶颈,只需vpabsd [mem] + vpaddd总共是Haswell / Skylake上的3个融合域uops。 (而且管道宽度为4,因此有空间用于循环开销。) - Peter Cordes
谢谢你的好回答!我从中学到了很多,并感谢你提供Godbolt的链接!使用它作为repl,我能够通过只是使if语句体变得更加复杂(例如将total += nums[i]转换为类似于total += nums[i]*(nums[i]-1)的更大的东西)来消除条件移动(并展示我正在寻找的分支预测失败惩罚)。我猜这在技术上仍然可以使用条件移动来完成,但优化器的启发式决定分支是更好的方法? - Dathan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接