在Rust中计算两个f64向量的点积的最快方法是什么?

3

我正在用Rust编写神经网络的实现,尝试计算两个矩阵的点积。我有以下代码:

fn dot_product(a: Vec<f64>, b: Vec<f64>) -> f64 {
    // Calculate the dot product of two vectors.
    assert_eq!(a.len(), b.len());
    let mut product: f64 = 0.0;
    for i in 0..a.len() {
        product += a[i] * b[i];
    }
    product
}

这个操作需要两个向量 ab(长度相同),并逐元素进行乘法运算(将向量 a 中的值1与向量 b 中的值1相乘,然后将结果加到向量 a 中的值2和向量 b 中的值2上,以此类推...)。
如果有更高效的方法,请说明是什么以及如何实现?

看起来不错。在你确定它是瓶颈之前,可以使用它。如果你已经确定需要最快的速度,也许可以考虑一下 SIMD?SIMD - anderspitman
2
你可以使用迭代器在一行代码中完成,例如 a.into_iter().zip(b).map(|(a, b)| a*b).sum()。但我认为这种方法的速度与其他方法相比是相当快的,而不是显著更快或更慢。 - trent
3个回答

3

这并不是一个全面的通用答案,但我想分享一些代码。

你的实现看起来很像我会做的,除非我知道它是应用程序中的瓶颈。然后我会研究更深奥的方法(也许是 SIMD)。

话虽如此,你可以考虑将你的函数改为接受切片引用。这样你就可以传递Vec或数组:

fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    // Calculate the dot product of two vectors. 
    assert_eq!(a.len(), b.len()); 
    let mut product = 0.0;
    for i in 0..a.len() {
        product += a[i] * b[i];
    }
    product
}

fn main() {
    println!("{}", dot_product(&[1.0,2.0], &[3.0,4.0]));
    println!("{}", dot_product(&vec![1.0,2.0], &vec![3.0,4.0]));
}

另请参阅:


0

在更深入地研究了@Ching-Chuan Chen的回答后,我认为这个问题的答案应该是:使用ndarrayblas。它比朴素的Rust实现快10倍。

你在dot周围看不到太多东西,因为通过包含ndarrayblas特性来完成大量工作。

    let x = Array1::random(d, Uniform::<f32>::new(0., 1.));
    let y = Array1::random(d, Uniform::<f32>::new(0., 1.));

    for _i in 0..n {
        let _res: f32 = x.dot(&y);
    }

需要注意的几点:1.比一个巨大向量上的点积更具代表性的是在较小向量上进行多个点积,因为应该按每个点计算sum,2.要超越打包在BLAS中的数十年线性代数研究将会非常困难。

0

我使用了rayonpacked_simd来计算点积,并找到了一种比Intel MKL更快的方法:

extern crate packed_simd;
extern crate rayon;
extern crate time;

use packed_simd::f64x4;
use packed_simd::f64x8;
use rayon::prelude::*;
use std::vec::Vec;

fn main() {
    let n = 100000000;
    let x: Vec<f64> = vec![0.2; n];
    let y: Vec<f64> = vec![0.1; n];

    let res: f64 = x
        .par_chunks(8)
        .map(f64x8::from_slice_unaligned)
        .zip(y.par_chunks(8).map(f64x8::from_slice_unaligned))
        .map(|(a, b)| a * b)
        .sum::<f64x8>()
        .sum();
    println!("res: {}", res);
}

我的Github上有这段代码。希望能对你有所帮助。


不需要使用 use std::vec::Vec;,它是预导入的一部分。为什么要使用 time crate?你能比较一下这个代码和其他答案的性能吗?为什么选择 unaligned - Shepmaster
我更深入地研究了这个基准测试,我认为它并不具有代表性,也不应该被使用。 - Alex Moore-Niemi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接