如何实现一个中位数堆

46

我希望实现一个Median-heap来跟踪给定整数集合的中位数,类似于Max-heap和Min-heap。API应该具有以下三个函数:

insert(int)  // should take O(logN)
int median() // will be the topmost element of the heap. O(1)
int delmedian() // should take O(logN)

我想使用数组(a)实现堆,其中数组索引k的子节点存储在数组索引2*k和2*k + 1中。为方便起见,该数组从索引1开始填充元素。

到目前为止,我已经做了这些:

中位数堆将有两个整数来跟踪迄今插入的大于当前中位数(gcm)和小于当前中位数(lcm)的整数数量。

if abs(gcm-lcm) >= 2 and gcm > lcm we need to swap a[1] with one of its children. 
The child chosen should be greater than a[1]. If both are greater, 
choose the smaller of two.

同样适用于另一种情况。我无法想出如何对元素进行下沉和上浮的算法。我认为它应该考虑到数字与中位数的接近程度,因此可以使用以下算法:

private void swim(int k) {
    while (k > 1 && absless(k, k/2)) {   
        exch(k, k/2);
        k = k/2;
    }
}

虽然我不能提供完整的解决方案。


1
没有给定值的多重性限制,这将变得困难。 - greybeard
6个回答

188
你需要两个堆:一个最小堆和一个最大堆。每个堆包含约一半的数据。最小堆中的每个元素都大于或等于中位数,而最大堆中的每个元素都小于或等于中位数。
当最小堆包含比最大堆多一个元素时,中位数在最小堆的顶部。当最大堆包含比最小堆多一个元素时,中位数在最大堆的顶部。
当两个堆包含相同数量的元素时,总元素数量是偶数。在这种情况下,您需要根据中位数的定义进行选择:a) 两个中间元素的平均值; b) 两者中较大的那个; c) 两者中较小的那个; d) 随机选择其中任意一个...
每次插入时,将新元素与堆顶上的元素进行比较,以便决定将其插入到何处。如果新元素大于当前中位数,则将其放入最小堆。如果它小于当前中位数,则将其放入最大堆。然后你可能需要重新平衡。如果堆的大小相差超过一个元素,则从元素更多的堆中提取最小/最大值并将其插入另一个堆中。
为了构建元素列表的中位数堆,我们应首先使用线性时间算法找到中位数。一旦知道了中位数,我们可以根据中位数值简单地向最小堆和最大堆添加元素。不需要平衡堆,因为中位数将把输入元素列表分成相等的两半。
如果您提取一个元素,则可能需要通过将一个元素从一个堆移动到另一个堆来补偿大小更改。这样,您可以确保在任何时候,两个堆具有相同的大小或仅相差一个元素。

4
那么元素的总数是偶数。根据您对中位数的定义进行操作:a)始终选择较低的值;b)始终选择较高的值;c)随机选择;d)中位数是这两个中间元素的平均值。 - comocomocomocomo
1
当你考虑它时,d)中位数解决方案有点奇怪,因为当你在堆上调用 remove() 时,你希望它给你实际删除的元素。如果你通过平均计算中位数,那么 remove() 将返回一个数字(你计算的中位数),但实际上会删除另一个数字。因此,如果您将n个元素添加到这种MedianHeap中,然后对每个元素进行removeMedian并将其放入另一种数据结构中,第二种数据结构中的元素将与进入MedianHeap的元素不同。 - angelatlarge
1
为了完整起见,我们还应该补充一点:为了构建元素列表的中位数堆,我们首先应该使用线性时间算法找到中位数。一旦知道了中位数,我们可以根据中位数值简单地将元素添加到最小堆和最大堆中。不需要平衡堆,因为中位数将把输入元素列表分成相等的两半。 - isubuz
2
一个线性时间算法,用于查找中位数:http://en.wikipedia.org/wiki/Median_of_medians - Erdem
1
当小根堆的元素数比大根堆多一个时,中位数在小根堆的顶部。小根堆的顶部应该是最小的数字吗? - cmal
显示剩余4条评论

11

这是一个Java实现的MedianHeap,借助以上comocomocomocomo的解释开发而成。

import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.Scanner;

/**
 *
 * @author BatmanLost
 */
public class MedianHeap {

    //stores all the numbers less than the current median in a maxheap, i.e median is the maximum, at the root
    private PriorityQueue<Integer> maxheap;
    //stores all the numbers greater than the current median in a minheap, i.e median is the minimum, at the root
    private PriorityQueue<Integer> minheap;

    //comparators for PriorityQueue
    private static final maxHeapComparator myMaxHeapComparator = new maxHeapComparator();
    private static final minHeapComparator myMinHeapComparator = new minHeapComparator();

    /**
     * Comparator for the minHeap, smallest number has the highest priority, natural ordering
     */
    private static class minHeapComparator implements Comparator<Integer>{
        @Override
        public int compare(Integer i, Integer j) {
            return i>j ? 1 : i==j ? 0 : -1 ;
        }
    }

    /**
     * Comparator for the maxHeap, largest number has the highest priority
     */
    private static  class maxHeapComparator implements Comparator<Integer>{
        // opposite to minHeapComparator, invert the return values
        @Override
        public int compare(Integer i, Integer j) {
            return i>j ? -1 : i==j ? 0 : 1 ;
        }
    }

    /**
     * Constructor for a MedianHeap, to dynamically generate median.
     */
    public MedianHeap(){
        // initialize maxheap and minheap with appropriate comparators
        maxheap = new PriorityQueue<Integer>(11,myMaxHeapComparator);
        minheap = new PriorityQueue<Integer>(11,myMinHeapComparator);
    }

    /**
     * Returns empty if no median i.e, no input
     * @return
     */
    private boolean isEmpty(){
        return maxheap.size() == 0 && minheap.size() == 0 ;
    }

    /**
     * Inserts into MedianHeap to update the median accordingly
     * @param n
     */
    public void insert(int n){
        // initialize if empty
        if(isEmpty()){ minheap.add(n);}
        else{
            //add to the appropriate heap
            // if n is less than or equal to current median, add to maxheap
            if(Double.compare(n, median()) <= 0){maxheap.add(n);}
            // if n is greater than current median, add to min heap
            else{minheap.add(n);}
        }
        // fix the chaos, if any imbalance occurs in the heap sizes
        //i.e, absolute difference of sizes is greater than one.
        fixChaos();
    }

    /**
     * Re-balances the heap sizes
     */
    private void fixChaos(){
        //if sizes of heaps differ by 2, then it's a chaos, since median must be the middle element
        if( Math.abs( maxheap.size() - minheap.size()) > 1){
            //check which one is the culprit and take action by kicking out the root from culprit into victim
            if(maxheap.size() > minheap.size()){
                minheap.add(maxheap.poll());
            }
            else{ maxheap.add(minheap.poll());}
        }
    }
    /**
     * returns the median of the numbers encountered so far
     * @return
     */
    public double median(){
        //if total size(no. of elements entered) is even, then median iss the average of the 2 middle elements
        //i.e, average of the root's of the heaps.
        if( maxheap.size() == minheap.size()) {
            return ((double)maxheap.peek() + (double)minheap.peek())/2 ;
        }
        //else median is middle element, i.e, root of the heap with one element more
        else if (maxheap.size() > minheap.size()){ return (double)maxheap.peek();}
        else{ return (double)minheap.peek();}

    }
    /**
     * String representation of the numbers and median
     * @return 
     */
    public String toString(){
        StringBuilder sb = new StringBuilder();
        sb.append("\n Median for the numbers : " );
        for(int i: maxheap){sb.append(" "+i); }
        for(int i: minheap){sb.append(" "+i); }
        sb.append(" is " + median()+"\n");
        return sb.toString();
    }

    /**
     * Adds all the array elements and returns the median.
     * @param array
     * @return
     */
    public double addArray(int[] array){
        for(int i=0; i<array.length ;i++){
            insert(array[i]);
        }
        return median();
    }

    /**
     * Just a test
     * @param N
     */
    public void test(int N){
        int[] array = InputGenerator.randomArray(N);
        System.out.println("Input array: \n"+Arrays.toString(array));
        addArray(array);
        System.out.println("Computed Median is :" + median());
        Arrays.sort(array);
        System.out.println("Sorted array: \n"+Arrays.toString(array));
        if(N%2==0){ System.out.println("Calculated Median is :" + (array[N/2] + array[(N/2)-1])/2.0);}
        else{System.out.println("Calculated Median is :" + array[N/2] +"\n");}
    }

    /**
     * Another testing utility
     */
    public void printInternal(){
        System.out.println("Less than median, max heap:" + maxheap);
        System.out.println("Greater than median, min heap:" + minheap);
    }

    //Inner class to generate input for basic testing
    private static class InputGenerator {

        public static int[] orderedArray(int N){
            int[] array = new int[N];
            for(int i=0; i<N; i++){
                array[i] = i;
            }
            return array;
        }

        public static int[] randomArray(int N){
            int[] array = new int[N];
            for(int i=0; i<N; i++){
                array[i] = (int)(Math.random()*N*N);
            }
            return array;
        }

        public static int readInt(String s){
            System.out.println(s);
            Scanner sc = new Scanner(System.in);
            return sc.nextInt();
        }
    }

    public static void main(String[] args){
        System.out.println("You got to stop the program MANUALLY!!");        
        while(true){
            MedianHeap testObj = new MedianHeap();
            testObj.test(InputGenerator.readInt("Enter size of the array:"));
            System.out.println(testObj);
        }
    }
}

你不需要将Comparator传递给minHeap,因为它已经按整数自然顺序升序排序。 - Enrico Giurin

3

一个完全平衡的二叉搜索树(BST)不就是一个中位数堆吗?虽然红黑树并不总是完美平衡的,但对于您的目的来说可能已经足够接近了。而且它保证了log(n)的性能!

AVL树比红黑树更加平衡,因此它们更接近真正的中位数堆。


1
然后,每次操作集合时,您需要维护一个中位数值。由于在BST中检索任意排名的元素需要O(logN)的时间,因此这仍然足够...我知道... - phoeagon
2
是的,但中位数堆将在常数时间内给出中位数。 - Bruce
2
@Bruce:这只是对于BSTs而言的真实情况:一旦你建立了结构,获取中位数(不删除它)的时间复杂度为O(0),但是,如果你删除它,那么你必须重新构建堆/树,这需要O(logn)的时间复杂度。 - angelatlarge
@angelatlarge 我喜欢你的想法。但是,根据您如何定义中位数,它可能比O(0)更昂贵。如果元素数量为偶数并且您将中位数定义为两个中间元素的平均值,则必须找到除根之外的另一个元素。 - Alma Alma

3

这是基于comocomocomocomo的回答编写的代码:

import java.util.PriorityQueue;

public class Median {
private  PriorityQueue<Integer> minHeap = 
    new PriorityQueue<Integer>();
private  PriorityQueue<Integer> maxHeap = 
    new PriorityQueue<Integer>((o1,o2)-> o2-o1);

public float median() {
    int minSize = minHeap.size();
    int maxSize = maxHeap.size();
    if (minSize == 0 && maxSize == 0) {
        return 0;
    }
    if (minSize > maxSize) {
        return minHeap.peek();
    }if (minSize < maxSize) {
        return maxHeap.peek();
    }
    return (minHeap.peek()+maxHeap.peek())/2F;
}

public void insert(int element) {
    float median = median();
    if (element > median) {
        minHeap.offer(element);
    } else {
        maxHeap.offer(element);
    }
    balanceHeap();
}

private void balanceHeap() {
    int minSize = minHeap.size();
    int maxSize = maxHeap.size();
    int tmp = 0;
    if (minSize > maxSize + 1) {
        tmp = minHeap.poll();
        maxHeap.offer(tmp);
    }
    if (maxSize > minSize + 1) {
        tmp = maxHeap.poll();
        minHeap.offer(tmp);
    }
  }
}

2

这里是一个Scala实现,遵循上面comocomocomocomo的想法。

class MedianHeap(val capacity:Int) {
    private val minHeap = new PriorityQueue[Int](capacity / 2)
    private val maxHeap = new PriorityQueue[Int](capacity / 2, new Comparator[Int] {
      override def compare(o1: Int, o2: Int): Int = Integer.compare(o2, o1)
    })

    def add(x: Int): Unit = {
      if (x > median) {
        minHeap.add(x)
      } else {
        maxHeap.add(x)
      }

      // Re-balance the heaps.
      if (minHeap.size - maxHeap.size > 1) {
        maxHeap.add(minHeap.poll())
      }
      if (maxHeap.size - minHeap.size > 1) {
        minHeap.add(maxHeap.poll)
      }
    }

    def median: Double = {
      if (minHeap.isEmpty && maxHeap.isEmpty)
        return Int.MinValue
      if (minHeap.size == maxHeap.size) {
        return (minHeap.peek+ maxHeap.peek) / 2.0
      }
      if (minHeap.size > maxHeap.size) {
        return minHeap.peek()
      }
      maxHeap.peek
    }
  }

2

另一种不使用最大堆和最小堆的方法是直接使用中位数堆。

在最大堆中,父节点大于子节点。我们可以有一种新类型的堆,其中父节点在子节点的“中间”-左子节点小于父节点,右子节点大于父节点。所有偶数项都是左子节点,所有奇数项都是右子节点。

与最大堆中可执行的上浮和下沉操作相同,也可以在此中位数堆中执行这些操作-稍作修改即可。在最大堆中的典型上浮操作中,插入的条目上浮直到它小于父条目,在这里,在中位数堆中,它将上浮直到它小于父亲(如果它是奇数条目)或大于父亲(如果它是偶数条目)。

以下是我对此中位数堆的实现。为简单起见,我使用了一个整数数组。

package priorityQueues;
import edu.princeton.cs.algs4.StdOut;

public class MedianInsertDelete {

    private Integer[] a;
    private int N;

    public MedianInsertDelete(int capacity){

        // accounts for '0' not being used
        this.a = new Integer[capacity+1]; 
        this.N = 0;
    }

    public void insert(int k){

        a[++N] = k;
        swim(N);
    }

    public int delMedian(){

        int median = findMedian();
        exch(1, N--);
        sink(1);
        a[N+1] = null;
        return median;

    }

    public int findMedian(){

        return a[1];


    }

    // entry swims up so that its left child is smaller and right is greater

    private void swim(int k){


        while(even(k) && k>1 && less(k/2,k)){

            exch(k, k/2);

            if ((N > k) && less (k+1, k/2)) exch(k+1, k/2);
            k = k/2;
        }

        while(!even(k) && (k>1 && !less(k/2,k))){

            exch(k, k/2);
            if (!less (k-1, k/2)) exch(k-1, k/2);
            k = k/2;
        }

    }

// if the left child is larger or if the right child is smaller, the entry sinks down
    private void sink (int k){

        while(2*k <= N){
            int j = 2*k;
            if (j < N && less (j, k)) j++;
            if (less(k,j)) break;
            exch(k, j);
            k = j;
        }

    }

    private boolean even(int i){

        if ((i%2) == 0) return true;
        else return false;
    }

    private void exch(int i, int j){

        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    private boolean less(int i, int j){

        if (a[i] <= a[j]) return true;
        else return false;
    }


    public static void main(String[] args) {

        MedianInsertDelete medianInsertDelete = new MedianInsertDelete(10);

        for(int i = 1; i <=10; i++){

            medianInsertDelete.insert(i);
        }

        StdOut.println("The median is: " + medianInsertDelete.findMedian());

        medianInsertDelete.delMedian();


        StdOut.println("Original median deleted. The new median is " + medianInsertDelete.findMedian());




    }
}



不幸的是,中位数并不总是以这种方式到达堆的顶部。我认为您只能保证堆的顶部在排序数组中距离中位数不超过log2(n)个跳跃。考虑维护此不变量的以下堆,但没有将中位数放在顶部: 70 17 87 11 93 25 92 2 67 85 95 8 52 0 94 1 6 39 78 3 - salt-die

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接