Java,查找数组中第K大的值

19

我曾经参加Facebook的面试,他们问了我这个问题。

假设您有一个具有N个不同值的无序数组

$input = [3,6,2,8,9,4,5]

实现一个函数,找到第K大的值。

例如:如果K = 0,则返回9。如果K = 1,则返回8。

我采用了以下方法。

private static int getMax(Integer[] input, int k)
{
    List<Integer> list = Arrays.asList(input);
    Set<Integer> set = new TreeSet<Integer>(list);

    list = new ArrayList<Integer>(set);
    int value = (list.size() - 1) - k;

    return list.get(value);
}
我刚刚测试了一下,根据这个问题,该方法运行良好。然而面试者说,“为了让你的生活更加复杂!假设你的数组包含数百万个数字,那么你的列表变得太慢了。在这种情况下你会怎么做?” 作为提示,他建议使用“最小堆”。就我所知,堆的每个子节点的值都不应超过根节点的值。所以,在这种情况下,如果我们假设3是根节点,那么6是它的子节点,而且其值大于根节点的值。我可能是错的,但你认为呢?基于“最小堆”的思想,该如何实现?

你为什么没有向面试官要求代码示例? - Olimpiu POP
在最小堆中,每个节点都小于或等于它的两个子节点。因此根节点应该是2而不是3。一种可能的布局是树2 -> [3,4],3 -> [5,6],4 -> [8,9] - paxdiablo
为什么要转换为TreeSet,然后再转回来,而不是直接调用Collections.sort呢? - user253751
@immibis 噢,是的,我不熟悉那个:( 我想从列表中删除所有重复项并按升序排序。我不确定 Collections.sort 是否会删除重复项! - Hesam
显示剩余2条评论
5个回答

19
他已经给出了完整的答案,不仅仅是一个提示。
而且你的理解是基于“最大堆”,而不是“最小堆”。它的工作方式是不言自明的。
在“最小堆”中,根节点具有最小值(小于其子节点)。
所以,你需要做的是,遍历数组并将K个元素填充到“最小堆”中。一旦完成,堆自动包含根节点上的最低元素。
现在,对于从数组中读取的每个“下一个”元素, -> 检查该值是否大于最小堆的根节点。 -> 如果是,则从最小堆中删除根节点,并将该值添加到其中。
当你遍历整个数组后,“最小堆”的根节点自动包含第k大的元素。
并且堆中的所有其他元素(确切地说是k-1个元素)都比k大。

谢谢,我找到了我的问题所在。会尝试找到正确的实现方式。 - Hesam
@Codebender 如果您能解释一下问题中给出的示例,那就太好了。我想学习它。谢谢。 - Uma Kanth
@UmaKanth,你想要堆实现的例子(谷歌上有很多例子)还是更详细地解释算法的其他部分? - Codebender
@UmaKanth,我已经提供了该问题的实现。请在答案中查看。 - YoungHobbit

5
这是使用Java中的PriorityQueue实现最小堆的代码。 复杂度:n * log k
import java.util.PriorityQueue;

public class LargestK {

  private static Integer largestK(Integer array[], int k) {
    PriorityQueue<Integer> queue = new PriorityQueue<Integer>(k+1);
    int i = 0;
    while (i<=k) {
      queue.add(array[i]);
      i++;
    }
    for (; i<array.length; i++) {
      Integer value = queue.peek();
      if (array[i] > value) {
        queue.poll();
        queue.add(array[i]);
      }
    }
    return queue.peek();
  }

  public static void main(String[] args) {
    Integer array[] = new Integer[] {3,6,2,8,9,4,5};
    System.out.println(largestK(array, 3));
  }
}

输出: 5

该代码循环遍历数组,时间复杂度为O(n)。优先队列(最小堆)的大小为k,因此任何操作都是log k。在最坏的情况下,即所有数字都按升序排序,时间复杂度为n*log k,因为对于每个元素,您需要删除堆顶并插入新元素。


@njzk2 复杂度是 n*log k,而不是 n*log n。代码循环遍历数组,时间复杂度为 O(n)PriorityQueue(最小堆)的大小为 k,因此任何操作都是 log k。在最坏的情况下,即所有数字按升序排序,复杂度为 n*log k。因为对于每个元素,您需要删除堆的顶部并插入新元素。 - YoungHobbit
1
没错,我没有考虑到 queue.poll(); 会确保队列始终为大小 k - njzk2

3

编辑:请查看此答案,其中提供了O(n)的解决方案。

您也可以使用PriorityQueue来解决这个问题:

public int findKthLargest(int[] nums, int k) {
        int p = 0;
        int numElements = nums.length;
        // create priority queue where all the elements of nums will be stored
        PriorityQueue<Integer> pq = new PriorityQueue<Integer>();

        // place all the elements of the array to this priority queue
        for (int n : nums){
            pq.add(n);
        }

        // extract the kth largest element
        while (numElements-k+1 > 0){
            p = pq.poll();
            k++;
        }

        return p;
    }

来自Java 文档:

实现说明:该实现提供了offerpollremove()add方法的O(log(n))时间;remove(Object)contains(Object)方法的线性时间;以及peekelementsize方法的常数时间。

以上算法的复杂度为O(nlogn),for循环运行n次。


谢谢@akhil,但是我认为如果你假设num数组包含一百万个项目,它仍然存在O(n)问题。你的ForWhile循环都是O(n)。或者我的知识混乱了:( - Hesam
此外,您的空间复杂度不必要地高,因为您将整个数组复制到队列中。 - Codebender
我要寻找的一个洞见是:由于我正在寻找第K大的值,如果我的堆/优先队列中已经有了K个元素,并且一个数字A进来了,它比那些最小的元素还要小--那么A肯定不会是最大的K个之一,因此不应该放入数据结构中。换句话说,数据结构永远不应该包含超过K个元素。这样,即使您有数千万个值--甚至数十亿!--并且它们正在流入(因为它们不能全部适合内存),如果K足够小,您仍然可以解决问题。 - yshavit
1
你的 while 循环是 O(n),每个 poll() 在其中都是 O(log(n))。因此,总复杂度为 O(nlog(n)) - Codebender
1
@immibis,我觉得很明显应该使用while循环,因为在polls的引用中确实使用了while循环。 - njzk2
显示剩余7条评论

0

如果数组/流中的元素数量未知,则基于堆的解决方案是完美的。但是,如果它们是有限的,但仍然希望在线性时间内获得优化的解决方案,该怎么办呢?

我们可以使用快速选择(在此处讨论过)。

数组 = [3,6,2,8,9,4,5]

让我们选择第一个元素作为枢轴:

枢轴= 3(在0号索引处),

现在将数组分区,使所有小于或等于3的元素位于左侧,大于3的数字位于右侧。就像在我的博客中讨论的快速排序一样。

因此,在第一次通过后 - [2,3,6,8,9,4,5]

枢轴索引为1(即它是第二个最低的元素)。现在再次应用相同的过程。

选择,现在是第6个,在上一个枢轴后的索引值为 - [2,3,4,5,6,8,9]

所以现在6已经在正确的位置上了。

继续检查是否已经找到了适当的数字(每次迭代中的第k大或第k小)。如果找到了,那么就完成了,否则继续。


0

对于常量值k的一种方法是使用部分插入排序。

(这假设有不同的值,但也可以轻松地改为处理重复值)

last_min = -inf
output = []
for i in (0..k)
    min = +inf
    for value in input_array
        if value < min and value > last_min
            min = value
    output[i] = min
print output[k-1]

(这是伪代码,但在Java中应该很容易实现)。

总体复杂度为O(n*k),这意味着它只有在k是常数或已知小于log(n)时才能很好地工作。

好的一面是,这是一个非常简单的解决方案。坏的一面是,它不如堆解决方案高效。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接