如何高效地在数组中找到第一个非零元素?

9
假设我们想要快速查找数组中第一个非零元素的索引,可以这样做:
fn leading_zeros(arr: &[u32]) -> Option<usize> {
    arr.iter().position(|&x| x != 0)
}

然而,rustc 会将这个过程编译成逐个检查的方式,可以在这里看到。通过使用 u128 类型每次检查四个单词,可以略微提高速度。在我的机器上,这样做的速度提升约为3倍。

fn leading_zeros_wide(arr: &[u32]) -> Option<usize> {
    let (beg, mid, _) = unsafe { arr.align_to::<u128>() };

    beg.iter().position(|&x| x != 0).or_else(|| {
        let left = beg.len() + 4 * mid.iter().position(|&x| x != 0).unwrap_or(mid.len());
        arr[left..].iter().position(|&x| x != 0).map(|p| p + left)
    })
}

有没有什么方法可以让这个过程更快?


以下是我使用的基准测试,以确定3倍加速:

#![feature(test)]
extern crate test;

fn v() -> Box<[u32]> {
    std::iter::repeat(0).take(1000).collect()
}

// Assume `leading_zeros` and `leading_zeros_wide` are defined here.

#[bench]
fn bench_leading_zeros(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros(&v[3..]))
}

#[bench]
fn bench_leading_zeros_wide(b: &mut test::Bencher) {
    let v = test::black_box(v());
    b.iter(|| leading_zeros_wide(&v[3..]))
}

1
@JohnKugelman 我没有使用 end 参数,因为切片 arr[left..] 包含了那部分内容。 - MERTON
7
我认为https://docs.rs/memx/latest/memx/fn.memnechr.html应该更快,更可靠。 - Stargateur
2
谢谢大家!不幸的是,memx crate 目前在 memnechr 方面存在一个 bug(至少在 0.1.18 版本中)。 - MERTON
2
我看到你的优化版本仍然没有使用SIMD,即使指定了编译器选项:https://rust.godbolt.org/z/8scnKToq8 这意味着它可以进一步优化。显然有一种方法可以直接使用CPU内部函数:x86arm。抱歉,我不会提供这个解决方案,因为我不懂Rust(我是通过[simd]标签看到这个问题的)。 - Alex Guteniev
1
我不知道如何在Rust中使用SIMD内嵌函数,但是你想要它在x86上发出的汇编指令是搜索包含非零元素的向量,然后使用Is there an efficient way to get the first non-zero element in an SIMD register using SIMD intrinsics?来查找该向量中的位置。就像我在Efficiently find least significant set bit in a large array?中对AVX2 C内嵌函数的回答一样(一旦找到非零元素,它会对其进行位扫描以找到位位置)。 - Peter Cordes
显示剩余7条评论
2个回答

4

64位:https://rust.godbolt.org/z/rsxh8P8Er

32位:https://rust.godbolt.org/z/3P3ejsnh1

我有一些Rust和汇编的经验,但我添加了一些测试。

#[cfg(target_feature = "avx2")]
pub mod avx2 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    fn first_nonzero_tiny(arr: &[u32]) -> Option<usize> {
        arr.iter().position(|&x| x != 0)
    }

    fn find_u32_zeros_8elems(arr: &[u32], offset: isize) -> i32 {
        unsafe {
            let ymm0 = _mm256_setzero_si256();
            let mut ymm1 = _mm256_loadu_si256(arr.as_ptr().offset(offset) as *const __m256i);
            ymm1 = _mm256_cmpeq_epi32(ymm1, ymm0);
            let ymm2 = _mm256_castsi256_ps(ymm1);
            _mm256_movemask_ps(ymm2)
        }
    }

    pub fn first_nonzero(arr: &[u32]) -> Option<usize> {
        let size = arr.len();
        if size < 8 {
            return first_nonzero_tiny(arr);
        }

        let mut i: usize = 0;
        let simd_size = size / 8 * 8;
        while i < simd_size {
            let mask: i32 = find_u32_zeros_8elems(&arr, i as isize);
            //println!("mask = {}", mask);
            if mask != 255 {
                return Some((mask.trailing_ones() as usize) + i);
            }
            i += 8;
            //println!("i = {}", i);
        }

        let last_chunk = size - 8;
        let mask: i32 = find_u32_zeros_8elems(&arr, last_chunk as isize);
        if mask != 255 {
            return Some((mask.trailing_ones() as usize) + last_chunk);
        }

        None
    }
}

use avx2::first_nonzero;

pub fn main() {
    let v = [0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [2];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [1, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(0));

    let v = [0, 1, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(1));

    let v = [0, 0, 1, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(2));

    let v = [0, 0, 0, 1, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(3));

    let v = [0, 0, 0, 0, 1, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(4));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 1, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(5));

    let v = [0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(6));

    let v = [0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(7));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(8));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, None);

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(16));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(15));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 3, 4, 5];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(14));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(17));

    let v = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49];
    let test1 = first_nonzero(&v);
    assert_eq!(test1, Some(18));
}

如果使用 i += 8并将其用作指向 i32 的指针的偏移量,然后在将结果强制转换为 __m256i*,像 C 风格的 _mm_loadu_si128( (const __m128i*) &arr[i] ) 而不是 i + (const __m128i*)arr,那么处理会更方便。即使对于您当前的代码,这也可以让您使用结束时的 i 值,而不是 n*8 .. arr.len()。虽然那时您必须要做 i < n-7; 您正在通过 n / = 8; 解决向量超出问题,因此在这一点上,这只是两种同样有效的手动 SIMD 数组索引方式。 - Peter Cordes
谢谢@PeterCordes,我编辑了我的答案,现在它使用 _mm256_loadu_si256 而不是 _mm256_load_si256 - Igor Zhukov
有趣的是,-C target-cpu=haswell(https://rust.godbolt.org/z/cP4zn5hx1)确实使其使用了 tzcnt。但仍然存在无用的 movzx 在 movmskps 之后,甚至没有到不同的寄存器,从而破坏了 mov-elimination。但它确实将 tzcnt(或 bsf)写入了另一个寄存器,导致对 RDX 的错误依赖,而该函数执行路径之前并未写入该寄存器。(在 Skylake 之前的 SnB-family 上,BSF 总是具有输出依赖性,因此可以在输入为 0 时保持 dst 不变;TZCNT 则具有输出依赖性)。总之,这些都是 rustc / LLVM 未优化的问题,在源代码中无法解决。 - Peter Cordes
1
@PeterCordes,显然 target-feature=+avx2,bmi,bmi2 不会启用 tzcnt,你必须使用 target-feature=+avx2,+bmi(此处似乎不需要bmi2)。 - Alex Guteniev
1
@IgorZhukov 非常感谢!我会跟进我的机器上的测试结果(这需要时间,因为我需要将其移植到aarch64)。 - MERTON
显示剩余10条评论

1

这里有一个解决方案,比基准方案更快,但可能仍然有很大的提升空间。

以下方案比基准方案的 first_nonzero 快了7.5倍。

/// Finds the position of the first nonzero element in a given slice which
/// contains a nonzero.
///
/// # Safety
///
/// The caller *has* to ensure that the input slice has a nonzero.
unsafe fn first_nonzero_padded(arr: &[u32]) -> usize {
    let (beg, mid, _) = arr.align_to::<u128>();
    beg.iter().position(|&x| x != 0).unwrap_or_else(|| {
        let left = beg.len()
            + 4 * {
                let mut p: *const u128 = mid.as_ptr();
                loop {
                    if *p.offset(0) != 0 { break p.offset(0); }
                    if *p.offset(1) != 0 { break p.offset(1); }
                    if *p.offset(2) != 0 { break p.offset(2); }
                    if *p.offset(3) != 0 { break p.offset(3); }
                    if *p.offset(4) != 0 { break p.offset(4); }
                    if *p.offset(5) != 0 { break p.offset(5); }
                    if *p.offset(6) != 0 { break p.offset(6); }
                    if *p.offset(7) != 0 { break p.offset(7); }
                    p = p.offset(8);
                }.offset_from(mid.as_ptr()) as usize
            };
        if let Some(p) = arr[left..].iter().position(|&x| x != 0) {
            left + p
        } else {
            core::hint::unreachable_unchecked()
        }
    })
}

有没有办法在Godbolt上编译(第一个修订版)中尝试的SIMD版本?使用use core_simd::u64x2;等等?https://godbolt.org/z/E6ozdhdYc对我来说无法使用rustc nightly。如果速度较慢,很可能您的mask8x8 :: from_array([* p.offset(00)!= ZERO,直到07 ])未编译为单个SSE4.1pcmpeqq或其他内容。我不知道是否会花费大量标量工作将8个2位比较结果打包成单个mask8x8,或者更糟糕的是将这些2位结果布尔化为1位结果? - Peter Cordes
但无论如何,将其描述为pcmpeqd / tzcnt几乎肯定是虚假的,所以你删除它也就不足为奇了:P 我并不惊讶在16字节块上进行早期退出会更好一些;你希望内部循环花费大量工作来准备循环后的东西,以便对非零元素的位置进行排序。例如,如果您预计有长时间的零运行,则甚至可以将多个向量组合在一起,然后稍后单独重新检查它们。(按缓存线大小的块工作很好,特别是如果您的数据按64对齐) - Peter Cordes
你当前的代码正在对两个64位块进行标量OR,并根据设置的FLAGS进行分支。https://godbolt.org/z/6fMEvveMb - Peter Cordes

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接