Rust - 如何在集合中找到第n个最频繁的元素

4

我无法想象这之前没有人问过这个问题,但我已经搜索了各个地方,却找不到答案。

我有一个可迭代对象,其中包含重复元素。我想计算每个元素在此可迭代对象中出现的次数,并返回第 n 个最频繁出现的元素。

我有一个可以完成这个任务的工作代码,但我确实怀疑它是否是实现这个任务最优方式。

use std::collections::{BinaryHeap, HashMap};

// returns n-th most frequent element in collection
pub fn most_frequent<T: std::hash::Hash + std::cmp::Eq + std::cmp::Ord>(array: &[T], n: u32) -> &T {
    // intialize empty hashmap
    let mut map = HashMap::new();

    // count occurence of each element in iterable and save as (value,count) in hashmap
    for value in array {
        // taken from https://doc.rust-lang.org/std/collections/struct.HashMap.html#method.entry
        // not exactly sure how this works
        let counter = map.entry(value).or_insert(0);
        *counter += 1;
    }

    // determine highest frequency of some element in the collection
    let mut heap: BinaryHeap<_> = map.values().collect();
    let mut max = heap.pop().unwrap();
    // get n-th largest value
    for _i in 1..n {
        max = heap.pop().unwrap();
    }

    // find that element (get key from value in hashmap)
    // taken from https://dev59.com/uLjna4cB1Zd3GeqP5y4z
    map.iter()
        .find_map(|(key, &val)| if val == *max { Some(key) } else { None })
        .unwrap()
}

有没有更好的方法或更优化的 std 方法来实现我想要的?或者说有一些社区制作的 Crate 可以使用。


我完全不了解Rust。我只是想知道是否有一种方法可以在你的集合上执行GroupBy和Count操作?按Count排序将允许您通过索引引用该项。 - Captain Kenpachi
1
@Ach113 不确定哪个更好,但 BinaryHeap::into_sorted_vec() 可以避免重复弹出堆(这可能不是很好),或者您可以将 hashmap 收集到 (count, item) 的 vec 中,然后按计数排序,并直接获取第 n 个元素。 - Masklinn
@Ach113,你只返回了一个值,所以我不明白你的意思,n-th对我来说不清楚,而且你说的话和你的代码不符。 - Stargateur
@Stargateur,我的意思是像在数组中找到第三大的元素(如果n=3),然后返回它。我不确定是否可以在O(n)的时间内完成(而不是始终跟踪3个变量)。 - Ach113
我会使用二叉堆最小值,只保留其中的n个较大值,并在哈希表中保存该项的引用,而不仅仅是计数。 - Stargateur
显示剩余4条评论
1个回答

2
您好!这段文本的翻译如下:
您的实现具有Ω(n log n)的时间复杂度,其中n是数组的长度。解决此问题的最优解具有检索第k个最频繁元素的Ω(n log k)的复杂度。通常的最佳解实现确实涉及二进制堆,但不是您使用它的方式。
以下是常见算法的建议实现:
use std::cmp::{Eq, Ord, Reverse};
use std::collections::{BinaryHeap, HashMap};
use std::hash::Hash;

pub fn most_frequent<T>(array: &[T], k: usize) -> Vec<(usize, &T)>
where
    T: Hash + Eq + Ord,
{
    let mut map = HashMap::new();
    for x in array {
        *map.entry(x).or_default() += 1;
    }

    let mut heap = BinaryHeap::with_capacity(k + 1);
    for (x, count) in map.into_iter() {
        heap.push(Reverse((count, x)));
        if heap.len() > k {
            heap.pop();
        }
    }
    heap.into_sorted_vec().into_iter().map(|r| r.0).collect()
}

(游乐场)

我更改了函数的原型,使其返回一个向量,其中包含出现频率最高的k个元素以及它们的计数,因为这正是你需要跟踪的。如果你只想要第k个最常见的元素,可以用[k-1][1]索引结果。

算法本身首先像你的代码一样构建一个元素计数映射,只是我写得更简洁。

接下来,我们为最常见的元素构建一个BinaryHeap。每次迭代后,此堆最多包含k个元素,这些元素是到目前为止看到的最常见的元素。如果堆中有超过k个元素,我们就会删除最不常见的元素。由于我们总是删除到目前为止最不常见的元素,因此堆始终保留到目前为止最常见的k个元素。我们需要使用Reverse包装器来获取最小堆,如在BinaryHeap文档所述的文档中所述。
最后,我们将结果收集到一个向量中。 into_sorted_vec()函数基本上为我们完成了这项工作,但我们仍然希望从其Reverse包装器中解压出项目-该包装器是我们函数的实现细节,不应返回给调用者。
在Rust Nightly中,我们还可以使用 into_iter_sorted()方法,节省一个向量分配的开销。
这个答案中的代码确保堆基本上被限制为k个元素,因此对堆的插入具有Ω(log k)的复杂度。在你的代码中,你一次性将数组中的所有元素都推送到堆中,没有限制堆的大小,因此你最终得到了插入的Ω(log n)的复杂度。你实际上是使用二叉堆对计数列表进行排序。这种方法可以工作,但它肯定不是最简单或最快的方式来实现这个目标,因此很少有理由采用这种方法。

等一下...为什么(usize, &T)是有序的?我们只想比较usize... - Stargateur
1
这里有另一种变体,按照频率和原始数组中的位置排序,因此您不需要 T: Ord。(尽管如果 T 必须是 Ord,您可以使用 BTreeMap 替代并消除 T: Hash...) - trent
1
@Stargateur 我并没有尝试包含微小的优化,而是尽可能保持代码简单易读。代码比较 (usize, &T) 对是因为它并不重要。我首先引入了一个自定义的 Item 结构体,其中包含计数和对项目的引用,并添加了一个自定义的 PartialEq 实现,仅比较计数,但需要更多的代码来解释,唯一的好处是我们可以放弃 T: Ord,这已经在问题中被 OP 包含了。 - Sven Marnach
显示剩余3条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接