在一个数组中找到最大值的索引的最快方法是什么?

12

我有一个f32类型的二维数组(来自ndarray::ArrayView2),我想在每一行中找到最大值的索引,并将索引值放入另一个数组中。

Python中的等效操作类似于:

import numpy as np

for i in range (0, max_val, batch_size):
   sims = xp.dot(batch, vectors.T) 
   # sims is the dot product of batch and vectors.T
   # the shape is, for example, (1024, 10000)

   best_rows[i: i+batch_size] = sims.argmax(axis = 1)

在Python中,.argmax函数非常快速,但我没有发现Rust中有类似的功能。有什么最快的方法可以实现吗?


1
它实际上是一个数组还是一个Vec - jhpratt
1
你需要在这里展示一些Rust代码以提供上下文。 - tadman
@jhpratt 正如问题中所提到的,它是一个数组。 - user9773683
1
是的,但很多时候,人们在实际使用向量时会说数组。您的澄清加上实际类型的说明是有帮助的。 - jhpratt
2个回答

7
考虑一般的Ord类型的简单情况:答案会略有不同,具体取决于您是否知道这些值是Copy的,但以下是代码:
fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}

基本思路是将数组中的每个项(实际上是一个切片——无论它是Vec还是数组或其他更奇特的东西)与其索引配对,使用std::iter::Iterator函数仅根据值(而不是索引)查找最大值,然后仅返回索引。如果切片为空,则返回None。根据文档,将返回最右侧的索引;如果需要最左侧的索引,请在enumerate()之后执行rev()。

rev()enumerate()max_by_key()max_by()这里有文档说明;slice::iter()这里有文档说明(但作为rust开发人员,这应该是你需要记住的东西之一);mapOption::map(),在这里有文档说明(同上)。哦,而cmpOrd::cmp,但大多数情况下,您可以使用不需要它的Copy版本(例如,如果您正在比较整数)。


现在有一个问题: 由于IEEE浮点数的工作方式,f32并不是Ord。大多数语言都忽略了这一点,并且算法存在微妙的错误。提供Ord总序(通过将所有NaN声明为相等,并大于所有数字)最流行的箱子似乎是ordered-float。假设它实现正确,它应该非常轻量级。它确实调用了num_traits,但这是最流行的数字库的一部分,因此可能已经被其他依赖项调用。

您可以通过在切片迭代器上映射ordered_float::OrderedFloat(元组类型的“构造函数”)来在这种情况下使用它(slice.iter().map(ordered_float::OrderedFloat))。由于只想要最大元素的位置,因此无需在之后提取f32。


请注意,这是针对一维向量的,但 OP 正在使用二维数组,因此他需要迭代他的数组行并为每一行调用 position_max - Jmb
3
对于一维情况,另一个选项是(0..slice.len()).max_by_key(|i| &slice[i])。(我没有测试过,但无论T是否为Copy都应该有效。) - Sven Marnach
是的,老实说那可能更容易理解。对于T: !Copymax_by_key问题是它(隐式地)要求返回类型为Ord + 'static;我不确定是否在算法中实际上确实需要这个或者可能是一个疏忽。 - David A

4

David A的方法很酷,但是正如提到的那样,有一个问题: f32f64没有实现Ord::cmp. (真的让人头疼.)

有多种解决方法: 你可以自己实现cmp,或者使用ordered-float等.

在我的情况下,这是一个更大项目的一部分,我们非常谨慎地使用外部包。此外,我很确定我们没有任何NaN值。因此,我更喜欢使用fold,如果你仔细看max_by_key源代码,你会发现他们也在使用它。

for (i, row) in matrix.axis_iter(Axis(1)).enumerate() {
    let (max_idx, max_val) =
        row.iter()
            .enumerate()
            .fold((0, row[0]), |(idx_max, val_max), (idx, val)| {
                if &val_max > val {
                    (idx_max, val_max)
                } else {
                    (idx, *val)
                }
            });
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接