在尝试达到最佳编译器输出的过程中,移除Rust循环中的边界检查

8
为了确定我是否可以/应该使用 Rust 而不是默认的 C/C++,我正在研究各种边缘情况,主要是考虑以下问题:在 0.1% 的情况下,当它确实很重要时,我是否总是能够获得与 gcc 相同的编译器输出(使用适当的优化标志)?答案很可能是否定的,但让我们看看... Reddit 上有一个相当奇特的例子,研究了无分支排序算法的子程序的编译器输出。
这是基准 C 代码:
#include <stdint.h>
#include <stdlib.h>
int32_t* foo(int32_t* elements, int32_t* buffer, int32_t pivot)
{
    size_t buffer_index = 0;

    for (size_t i = 0; i < 64; ++i) {
        buffer[buffer_index] = (int32_t)i;
        buffer_index += (size_t)(elements[i] < pivot);
    }
}

以下是与编译器输出相关的godbolt链接

Rust的第一次尝试如下所示:

pub fn foo0(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        buffer[buffer_index] = i as i32;
        buffer_index += (elements[i] < pivot) as usize; 
    }
}

这里有很多边界检查,详见godbolt

下一步尝试消除第一个边界检查:

pub unsafe fn foo1(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            buffer[buffer_index] = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

这有所改善了(请看与前面相同的godbolt链接)。

最后,让我们尝试完全删除边界检查:

use std::ptr;

pub unsafe fn foo2(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    unsafe {
        for i in 0..buffer.len() {
            ptr::replace(&mut buffer[buffer_index], i as i32);
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

这会产生与foo1相同的输出,因此ptr::replace仍会执行边界检查。在这里,使用那些unsafe操作是超出我的能力范围的。这引出了我的两个问题:

  • 如何消除边界检查?
  • 分析这种边缘情况甚至有意义吗?还是说如果呈现整个算法而不仅仅是其中一小部分,Rust编译器会看透这一点。

关于最后一点,我很好奇,通常是否可以将 Rust 精简到与 C 一样“直接”,即更接近底层。经验丰富的 Rust 程序员可能会对这种调查方式感到不适,但它确实存在...


4
请注意,您可以在不使用不安全代码的情况下通过 for (i, elt) in elements.iter().enumerate().take(buffer.len()) 来取消对 elements 的边界检查,然后使用 elt 代替 elements[i]playground)。 - Jmb
2个回答

4
  • 如何消除边界检查?

数组通过将其deref强制转换为切片,也具有未经检查的可变获取形式 get_unchecked_mut

pub unsafe fn foo(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            *buffer.get_unchecked_mut(buffer_index) = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

这可能产生与使用Clang编译等效C代码时获得的机器代码相同的结果。https://godbolt.org/z/ddxP1P

  • 分析这样的边缘情况是否有意义?或者如果提供整个算法而不仅仅是一小部分,Rust编译器是否能够看穿所有这些情况。

像往常一样,在你确定代码的那一部分存在瓶颈的情况下,应对任何这些情况进行基准测试。否则,这是一种过早进行的优化,有一天可能会后悔。特别是在Rust中,编写unsafe代码的决定不应该轻率地做出。可以安全地说,在许多情况下,仅仅为了删除边界检查所付出的努力和风险就超过了预期的性能收益。

关于最后一点,我好奇的是,一般而言,Rust是否可以被删减到与C“文字级”即接近硬件的程度。

不行,并且你也不希望这样做,主要有两个原因:

  1. 尽管Rust的抽象能力很强,但不支付你不使用的东西的原则仍然非常相关,类似于C++。请参见什么使得抽象成本为零。在边界检查的情况下,这仅是一种语言设计决策的结果,当编译器无法确保这样的访问是内存安全时,会始终执行空间检查。
  2. C 也不是那么低级。 它可能看起来文字级且接近硬件,直到它真的不是了。

另请参见:


我绝对不想暗示,或让任何读者暗示,Rust比语言X更慢/快/好/差。问题特别关注的是那0.1%的情况,可能需要进行大量优化。当然,你是正确的,在几乎所有情况下,这样的优化绝对不值得努力。 - mcmayer

2
你可以使用老派的指针算术来实现这一点。
const N: usize = 64;
pub fn foo2(elements: &Vec<i32>, mut buffer: [i32; N], pivot: i32) -> () {
    assert!(elements.len() >= N);
    let elements = &elements[..N];
    let mut buff_ptr = buffer.as_mut_ptr();
    for (i, &elem) in elements.iter().enumerate(){
        unsafe{
            // SAFETY: We increase ptr strictly less or N times
            *buff_ptr = i as i32;
            if elem < pivot{
                buff_ptr = buff_ptr.add(1);
            }
        }
    }
}

这个版本编译成:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi

        // Loop goes here
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        // Conditional branch to the loop beginning
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2

正如你所看到的,循环被展开并且单个分支是循环迭代跳转。

然而,我很惊讶,这个函数没有被消除,因为它没有任何效果:它应该被编译成简单的noop。可能,在内联之后会这样做。

此外,我要说的是,将参数更改为&mut不会改变代码:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2

很可能rustc在LLVM IR中将该函数表示为接受指针类型的缓冲区参数。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接