Scala中位数实现

39

如何在scala中快速实现中位数?

这是我在rosetta code上找到的:

  def median(s: Seq[Double])  =
  {
    val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }

我不喜欢这个方法,因为它会进行排序。我知道有一些方法可以在线性时间内计算中位数。

编辑:

我希望拥有一组中位数函数,以便在各种情况下使用:

  1. 快速、原地计算中位数,可在线性时间内完成
  2. 适用于流式遍历的中位数,但是您只能在内存中保留 O(log n) 个值 就像这样
  3. 适用于流式遍历的中位数,在内存中最多只能保留 O(log n) 个值,并且您最多只能遍历一次流(这可能吗?)

请仅发布可以 编译并正确计算中位数 的代码。为简单起见,您可以假设所有输入都包含奇数个值。


快速谷歌搜索给了我这个这个。基本上,你要找的是选择算法。Scala版本留给读者作为练习。 - Taylor Leese
1
一个“好”的算法要复杂得多。请搜索“中位数的中位数”或“五个数的中位数”。 - Landei
一个实现良好的(即库)排序算法,在您应用程序的实际情况下可能比某些据称具有线性时间复杂度的算法实现更快。至于上面的代码,根据您假设的 Seq 实现类型,您可以省略 split 并改为使用索引访问。 - Raphael
我认为第三种情况是不可能的。比如说,我得到了1000到1500之间的数字。中位数是1250。现在,如果我开始得到小于1000的数字,中位数将减少一个,直到达到1000。同样地,如果我开始得到大于1500的数字,中位数将增加,直到达到1500。因此,你需要保留到目前为止看到的所有数字。 - Daniel C. Sobral
1个回答

64

不可变算法

第一个算法Taylor Leese指示,虽然是二次的,但平均线性。然而,这取决于枢轴选择。因此,我在这里提供了一个具有可插拔枢轴选择的版本,包括随机枢轴和中位数枢轴(保证线性时间)。

import scala.annotation.tailrec

@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partition (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partition (a ==)
        if (s.size > k) a
        else findKMedian(b, k - s.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
}

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)

随机枢轴(二次,线性平均),不可变

这是随机枢轴选择。带有随机因素的算法分析比普通算法更加棘手,因为它主要涉及概率和统计学。

def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))

中位数的中位数(线性时间),不可变

中位数的中位数方法保证了与上述算法一起使用时具有线性时间。首先,需要一个计算最多5个数字的中位数的算法,这是中位数的中位数算法的基础。Rex Kerrthis answer中提供了此算法-该算法非常依赖于其速度。

def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}

而且,接下来是中位数算法本身。基本上,它保证所选择的枢轴将大于至少30%并小于列表的其他30%,这足以保证先前算法的线性性。有关详细信息,请查看另一个答案中提供的维基百科链接。
def medianOfMedians(arr: Array[Double]): Double = {
    val medians = arr grouped 5 map medianUpTo5 toArray;
    if (medians.size <= 5) medianUpTo5 (medians)
    else medianOfMedians(medians)
}

原地算法

这里是一个原地算法的版本。我使用了一个实现了原地分区的类,其中包含一个支持数组,以便对算法进行最小化的更改。

case class ArrayView(arr: Array[Double], from: Int, until: Int) {
    def apply(n: Int) = 
        if (from + n < until) arr(from + n)
        else throw new ArrayIndexOutOfBoundsException(n)

    def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
      var upper = until - 1
      var lower = from
      while (lower < upper) {
        while (lower < until && p(arr(lower))) lower += 1
        while (upper >= from && !p(arr(upper))) upper -= 1
        if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
      }
      (copy(until = lower), copy(from = lower))
    }

    def size = until - from
    def isEmpty = size <= 0

    override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
    def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partitionInPlace (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partitionInPlace (a ==)
        if (s.size > k) a
        else findKMedianInPlace(b, k - s.size)
    } else if (s.size < k) findKMedianInPlace(b, k - s.size)
    else findKMedianInPlace(s, k)
}

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)

随机轴点,原地排序

我只为原地排序算法实现了随机轴点,因为中位数的中位数需要比我目前定义的ArrayView类提供更多的支持。

def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))

直方图算法(O(log(n))内存),不可变

关于流,如果只能遍历一次,那么无法做到比O(n)更少的内存占用,除非你知道字符串长度(在这种情况下,它在我的书中就不再是流了)。

使用桶也有一些问题,但如果我们可以多次遍历它,那么我们就可以知道其大小、最大值和最小值,并从那里开始工作。例如:

def findMedianHistogram(s: Traversable[Double]) = {
    def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
        // The buckets
        def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
        val buckets = new Array[Int](numberOfBuckets)

        // The upper limit of each bucket
        val max = s.max
        val min = s.min
        val increment = (max - min) / numberOfBuckets
        val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

        // Return the bucket a number is supposed to be in
        def bucketIndex(d: Double) = indices indexWhere (d <=)

        // Compute how many in each bucket
        s foreach { d => buckets(bucketIndex(d)) += 1 }

        // Now make the buckets cumulative
        val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

        // The bucket where our target is at
        val medianBucket = partialTotals indexWhere (medianIndex <)

        // Keep track of how many numbers there are that are less 
        // than the median bucket
        val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

        // Test whether a number is in the median bucket
        def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

        // Get a view of the target bucket
        val view = s.view filter insideMedianBucket

        // If all numbers in the bucket are equal, return that
        if (view forall (view.head ==)) view.head
        // Otherwise, recurse on that bucket
        else medianHistogram(view, newDiscarded, medianIndex)
    }

    medianHistogram(s, 0, (s.size - 1) / 2)
}

测试和基准

为了测试算法,我使用Scalacheck,并将每个算法的输出与带有排序的平凡实现的输出进行比较。当然,这假设排序版本是正确的。

我正在使用所有提供的轴点选择以及固定轴点选择(数组的一半,向下舍入)来对上述算法进行基准测试。每个算法都使用三种不同的输入数组大小进行测试,并针对每个大小进行三次测试。

以下是测试代码:

import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, 
         reference: Array[Double] => Double): String = {
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    val resultEqualsReference = forAll { (arr: Array[Double]) => 
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] = 
    for (size <- arraySizes)
    yield for (iteration <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))

def testAndBenchmark: String = {
    val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
        "Random Pivot"      -> chooseRandomPivot,
        "Median of Medians" -> medianOfMedians,
        "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
    )
    val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
        "Random Pivot (in-place)" -> chooseRandomPivotInPlace,
        "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
    )
    val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
    val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
        yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
    val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
    val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
    val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

    val formattingString = "%%-%ds  %%s" format (algorithms map (_._1.length) max)

    // Tests
    val testResults = for ((name, algorithm) <- algorithms)
        yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

    // Benchmarks
    val arraySizes = List(100, 500, 1000)
    def formatResults(results: List[Long]) = results map ("%8d" format _) mkString

    val benchmarkResults: List[String] = for {
        (name, algorithm) <- algorithms
        results <- benchmark(algorithm, arraySizes).transpose
    } yield formattingString format (name, formatResults(results))

    val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

    "Tests" :: "*****" :: testResults ::: 
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}

结果

测试:

Tests
*****
Sorting                OK, passed 100 tests.
Histogram              OK, passed 100 tests.
Random Pivot           OK, passed 100 tests.
Median of Medians      OK, passed 100 tests.
Midpoint               OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place)    OK, passed 100 tests.

基准测试:

Benchmark
*********
Algorithm                   100     500    1000
Sorting                    1038    6230   14034
Sorting                    1037    6223   13777
Sorting                    1039    6220   13785
Histogram                  2918   11065   21590
Histogram                  2596   11046   21486
Histogram                  2592   11044   21606
Random Pivot                904    4330    8622
Random Pivot                902    4323    8815
Random Pivot                896    4348    8767
Median of Medians          3591   16857   33307
Median of Medians          3530   16872   33321
Median of Medians          3517   16793   33358
Midpoint                   1003    4672    9236
Midpoint                   1010    4755    9157
Midpoint                   1017    4663    9166
Random Pivot (in-place)     392    1746    3430
Random Pivot (in-place)     386    1747    3424
Random Pivot (in-place)     386    1751    3431
Midpoint (in-place)         378    1735    3405
Midpoint (in-place)         377    1740    3408
Midpoint (in-place)         375    1736    3408

分析

除了排序版本,所有算法的结果都与平均线性时间复杂度兼容。

中位数算法可以保证在最坏情况下具有线性时间复杂度,但比随机枢轴要慢得多。

固定枢轴选择略逊于随机枢轴,但在非随机输入时可能表现更差。

原地版本大约快230%〜250%,但进一步的测试(未显示)似乎表明这种优势随数组大小增加而增加。

直方图算法让我非常惊讶。它展示了平均线性时间复杂度,并且比中位数算法快33%。然而,输入确实是随机的。最坏情况是二次的 - 我在调试代码时看到了一些例子。


这段代码存在三个问题:(a)它无法编译(递归函数需要显式的返回类型),(b)它不是线性时间(因为分区是O(n),而且运行了O(n)次),(c)它产生了错误的答案。除此之外,没什么问题了。 - Michael Lorton
1
@Malvolio 这个算法的时间复杂度看起来是O(nlogn),因为数组Arr的大小每次平均会减半。然而,这种分析是肤浅的。该算法看起来非常像快速排序,但只递归了一半的分区,这使它比快速排序更快。此外,它不需要降到1大小的分区。至于错误,它们大多与原始算法隐式地从分区中删除“a”以及在声明“a”时缺少“arr”有关。Off by one的错误很糟糕。 - Daniel C. Sobral
1
@Malvolio,随机数并不是“毫无意义”的。如果你不使用随机数,有人可能会猜测你正在使用的策略,在程序需要O(n^2)时间的情况下选择一个例子,并挂掉你的服务器。该算法在平均情况下是正确和线性的。 - adamax
1
@Raphael 这个论点是完全有效的。假设每次数组长度减少一半。那么第一次迭代需要n个单位的时间,第二次需要n/2个单位的时间,第三次需要n/4个单位的时间,以此类推,总共需要的时间为n+n/2+n/4 + ... = 2*n。当然这只是一个直观的解释,严格的证明可以在任何算法书中找到。 - adamax
@dsg 这个算法因为随机选择,可以给你平均线性时间。使用随机化分析算法更加复杂和不直观。甚至可能出现这个算法的摊销线性复杂度。维基百科提到,没有支撑引用的情况下,这个随机算法在平均情况下比完全线性的算法更好。所以..你还想要什么? - Daniel C. Sobral
显示剩余18条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接