如何编写一个通用函数,可以将ndarray数组或ArrayView作为输入参数?

4
我正在使用 ndarray 编写一组数学函数,希望可以在任何类型的 ArrayBase 上运行。然而,我在指定涉及到的特征/类型时遇到了问题。
这个基本函数可以处理 OwnedReprViewRepr 数据:
use ndarray::{prelude::*, Data}; // 0.13.1

fn sum_owned(x: Array<f64, Ix1>) -> f64 {
    x.sum()
}

fn sum_view(x: ArrayView<f64, Ix1>) -> f64 {
    x.sum()
}

fn main() {
    let a = Array::from_shape_vec((4,), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    println!("{:?}", sum_owned(a.clone()));

    let b = a.slice(s![..]);
    println!("{:?}", sum_view(b));

    // Complains that OwnedRepr is not ViewRepr
    //println!("{:?}", sum_view(a.clone()));
}

我能理解为什么被注释掉的部分无法编译,但是我对泛型理解不够深入,不能写出更通用的代码。

这是我尝试过的:

use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data,
{
    x.sum()
}

编译器错误提示“Data”不够具体,但我无法解析它以找出解决方案。
error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: std::clone::Clone` is not satisfied
 --> src/lib.rs:8:7
  |
6 |     S: Data,
  |             - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: std::clone::Clone`
7 | {
8 |     x.sum()
  |       ^^^ the trait `std::clone::Clone` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`

error[E0277]: the trait bound `<S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero` is not satisfied
 --> src/lib.rs:8:7
  |
6 |     S: Data,
  |             - help: consider further restricting the associated type: `, <S as ndarray::data_traits::RawData>::Elem: num_traits::identities::Zero`
7 | {
8 |     x.sum()
  |       ^^^ the trait `num_traits::identities::Zero` is not implemented for `<S as ndarray::data_traits::RawData>::Elem`

error[E0308]: mismatched types
 --> src/lib.rs:8:5
  |
4 | fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
  |                                            --- expected `f64` because of return type
...
8 |     x.sum()
  |     ^^^^^^^ expected `f64`, found associated type
  |
  = note:         expected type `f64`
          found associated type `<S as ndarray::data_traits::RawData>::Elem`
  = note: consider constraining the associated type `<S as ndarray::data_traits::RawData>::Elem` to `f64`
  = note: for more information, visit https://doc.rust-lang.org/book/ch19-03-advanced-traits.html

当您遵循编译器的建议(“help: consider...” / “note: consider ...”)时会发生什么? - Shepmaster
@Shepmaster 感谢您的提示。我确实尝试添加了Elem = A,但似乎引发了一系列新问题(可能与此无关),所以我决定在这里提出更一般的问题。我现在看到最相关的部分是“考虑将关联类型<S as ndarray::data_traits::RawData> ::Elem限制为f64”,但我只是无法将其转化为实际代码... - ssokolen
1个回答

3

如果你查看你试图调用的ndarray::ArrayBase::sum函数的定义:

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    pub fn sum(&self) -> A
    where
       A: Clone + Add<Output = A> + Zero
    {
         // etc.
    }
}

很明显,在您的情况下,A = f64D = Ix1,但您仍需要指定约束条件 S: Data<Elem = f64>。因此:

use ndarray::prelude::*;
use ndarray::Data;

fn sum_general<S>(x: ArrayBase<S, Ix1>) -> f64
where
    S: Data<Elem = f64>,
{
    x.sum()
}

这正是编译器建议时的意思:

  = 注意:期望类型为 `f64`
          发现关联类型 `<S as ndarray::data_traits::RawData>::Elem`
  = 注意:考虑将关联类型 `<S as ndarray::data_traits::RawData>::Elem` 约束为 `f64`

太好了!我确实尝试过 Elem = A,但那似乎只会引起更多问题,所以我放弃了那个想法,认为自己可能漏掉了什么......我没有想到只需在我的上下文中填写A即可。 - ssokolen

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接