如何为一个结构体实现 Fn trait 以支持不同类型的参数?

11

我有一个简单的分类器:

struct Clf {
    x: f64,
}
分类器如果观测值小于x返回0,大于x返回1。我想为这个分类器实现调用运算符,但函数应能够接受浮点数或向量作为参数。对于向量,输出是一个大小与输入向量相同的0或1向量。
let c = Clf { x: 0 };
let v = vec![-1, 0.5, 1];
println!("{}", c(0.5));   // prints 1
println!("{}", c(v));     // prints [0, 1, 1]

在这种情况下,我应该如何编写Fn的实现?

impl Fn for Clf {
    extern "rust-call" fn call(/*...*/) {
        // ...
    }
}
4个回答

14
短答案是:你不能。至少它不会按照你想要的方式工作。我认为展示这一点的最佳方法是通过演示发生的事情,但总体思路是Rust不支持函数重载。
对于这个例子,我们将实现FnOnce,因为Fn需要FnMut,而FnMut需要FnOnce。所以,如果我们把这一切都搞定了,我们就可以为其他函数特性做到这一点。
首先,这是不稳定的,所以我们需要一些功能标志。
#![feature(unboxed_closures, fn_traits)]

那么,让我们为获取 f64 值编写 impl

impl FnOnce<(f64,)> for Clf {
    type Output = i32;
    extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
        if args.0 > self.x {
            1
        } else {
            0
        }
    }
}
< p > Fn特质系列的参数是通过元组提供的,因此这就是(f64,)语法;它是一个只有一个元素的元组。

这很好,我们现在可以执行c(0.5),虽然它会消耗c,直到我们实现其他特质为止。

现在让我们对Vec做同样的事情:

impl FnOnce<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0
            .iter()
            .map(|&f| if f > self.x { 1 } else { 0 })
            .collect()
    }
}

Rust 1.33 nightly之前,您无法直接调用c(v)或者c(0.5)(这在之前是可行的),因为我们会遇到函数类型未知的错误。基本上,这些版本的Rust不支持函数重载。但是我们仍然可以使用完全限定语法调用这些函数,其中c(0.5)变成了FnOnce::call_once(c, (0.5,))


不了解你的整体情况,我想通过以下方式简单地为Clf提供两个函数来解决这个问题:

impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x {
            1
        } else {
            0
        }
    }

    fn classify_vec(&self, vals: Vec<f64>) -> Vec<u32> {
        vals.into_iter().map(|v| self.classify(v)).collect()
    }
}

然后你的用法示例变成了。
let c = Clf { x: 0 };
let v = vec![-1, 0.5, 1];
println!("{}", c.classify(0.5));   // prints 1
println!("{}", c.classify_vec(v)); // prints [0, 1, 1]

我希望您可以将第二个函数classify_slice变得更加通用,接受&[f64]参数,这样您就可以通过引用来使用它与Vec一起使用:c.classify_slice(&v)。请保留HTML标签。

我想问一下,这种语法现在是过时的(历史性的)吗? - wizzwizz4

12
这确实是可能的,但需要一个新的特性和大量的混乱。
如果您从抽象开始
enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}

use VecOrScalar::*;

您希望找到一种使用类型转换的方法。
T      (hidden) -> VecOrScalar<T> -> T      (known)
Vec<T> (hidden) -> VecOrScalar<T> -> Vec<T> (known)

因为这样你可以将一个“隐藏”的类型T,用VecOrScalar包装起来,并使用match提取真实的类型T
你还需要...
T      (known) -> bool      = T::Output
Vec<T> (known) -> Vec<bool> = Vec<T>::Output

但是如果没有高阶类型,这有点棘手。相反,您可以这样做:

T      (known) -> VecOrScalar<T> -> T::Output
Vec<T> (known) -> VecOrScalar<T> -> Vec<T>::Output

如果您允许存在可能会出现panic的分支。
因此,该特质将是:
trait FromVecOrScalar<T> {
    type Output;

    fn put(self) -> VecOrScalar<T>;

    fn get(out: VecOrScalar<bool>) -> Self::Output;
}

使用实现

impl<T> FromVecOrScalar<T> for T {
    type Output = bool;

    fn put(self) -> VecOrScalar<T> {
        Scalar(self)
    }

    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Scalar(val) => val,
            Vector(_) => panic!("Wrong output type!"),
        }
    }
}

impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;

    fn put(self) -> VecOrScalar<T> {
        Vector(self)
    }

    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Vector(val) => val,
            Scalar(_) => panic!("Wrong output type!"),
        }
    }
}

你的类型

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

首先实现这两个分支:

impl Clf {
    fn calc_scalar(self, f: f64) -> bool {
        f > self.x
    }

    fn calc_vector(self, v: Vec<f64>) -> Vec<bool> {
        v.into_iter().map(|x| self.calc_scalar(x)).collect()
    }
}

然后它将通过为 T: FromVecOrScalar<f64> 实现 FnOnce 来进行调度。

impl<T> FnOnce<(T,)> for Clf
where
    T: FromVecOrScalar<f64>,
{

使用类型

    type Output = T::Output;
    extern "rust-call" fn call_once(self, (arg,): (T,)) -> T::Output {

首先,调度程序将私有类型封装起来,因此您可以使用enum提取它,然后使用T::get获取结果,以再次隐藏它。

        match arg.put() {
            Scalar(scalar) => T::get(Scalar(self.calc_scalar(scalar))),
            Vector(vector) => T::get(Vector(self.calc_vector(vector))),
        }
    }
}

然后,成功:
fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c(0.5f64));
    println!("{:?}", c(v));
}

由于编译器可以看穿这些废话,它实际上编译成与直接调用calc_方法基本相同的汇编代码。

这并不意味着写起来很好。像这样的重载是痛苦的、易碎的,并且绝对是一个坏主意™。不要这样做,但知道你可以这样做也没问题。


5
你不能(但请读完回答)。
首先,明确实现Fn*系列特质是不稳定的,随时可能发生变化,因此依赖于它是一个坏主意。
其次,更重要的是,在Rust 1.33 nightly之前,Rust编译器就无法让你调用具有不同参数类型的Fn*实现的值。因为通常情况下它无法确定你想要做什么。唯一的解决方法是完全指定要调用的特质,但这时你已经失去了任何可能的人机工效性优势。
与其尝试使用Fn*特质,不如定义和实现自己的特质。我在问题上做了一些改动以避免/修复可疑的方面。
struct Clf {
    x: f64,
}

trait ClfExt<T: ?Sized> {
    type Result;
    fn classify(&self, arg: &T) -> Self::Result;
}

impl ClfExt<f64> for Clf {
    type Result = bool;
    fn classify(&self, arg: &f64) -> Self::Result {
        *arg > self.x
    }
}

impl ClfExt<[f64]> for Clf {
    type Result = Vec<bool>;
    fn classify(&self, arg: &[f64]) -> Self::Result {
        arg.iter().map(|v| self.classify(v)).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c.classify(&0.5f64));
    println!("{:?}", c.classify(&v[..]));
}

如何使用Fn*特质

我包含了这个内容是为了完整性;实际上不要这样做。 不仅不被支持,而且非常丑陋。

#![feature(fn_traits, unboxed_closures)]

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

impl FnOnce<(f64,)> for Clf {
    type Output = bool;
    extern "rust-call" fn call_once(self, args: (f64,)) -> Self::Output {
        args.0 > self.x
    }
}

impl<'a> FnOnce<(&'a [f64],)> for Clf {
    type Output = Vec<bool>;
    extern "rust-call" fn call_once(self, args: (&'a [f64],)) -> Self::Output {
        args.0
            .iter()
            .cloned()
            .map(|v| FnOnce::call_once(self, (v,)))
            .collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];

    // Before 1.33 nightly
    println!("{}", FnOnce::call_once(c, (0.5f64,)));
    println!("{:?}", FnOnce::call_once(c, (&v[..],)));

    // After
    println!("{}", c(0.5f64));
    println!("{:?}", c(&v[..]));
}

-2

您可以使用夜间版和不稳定的功能:

#![feature(fn_traits, unboxed_closures)]
struct Clf {
    x: f64,
}

impl FnOnce<(f64,)> for Clf {
    type Output = i32;
    extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
        if args.0 > self.x {
            1
        } else {
            0
        }
    }
}

impl FnOnce<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0
            .iter()
            .map(|&f| if f > self.x { 1 } else { 0 })
            .collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{:?}", c(0.5));

    let c = Clf { x: 0.0 };
    println!("{:?}", c(v));
}

1
这个答案似乎已经由现有答案提供。在提供新答案之前,请仔细阅读所有现有答案。如果您的答案确实不同,请清楚地区分出其独特之处。 - Shepmaster

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接