不可变算法
第一个算法由Taylor Leese指示,虽然是二次的,但平均线性。然而,这取决于枢轴选择。因此,我在这里提供了一个具有可插拔枢轴选择的版本,包括随机枢轴和中位数枢轴(保证线性时间)。
import scala.annotation.tailrec
@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partition (a >)
if (s.size == k) a
else if (s.isEmpty) {
val (s, b) = arr partition (a ==)
if (s.size > k) a
else findKMedian(b, k - s.size)
} else if (s.size < k) findKMedian(b, k - s.size)
else findKMedian(s, k)
}
def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
随机枢轴(二次,线性平均),不可变
这是随机枢轴选择。带有随机因素的算法分析比普通算法更加棘手,因为它主要涉及概率和统计学。
def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))
中位数的中位数(线性时间),不可变
中位数的中位数方法保证了与上述算法一起使用时具有线性时间。首先,需要一个计算最多5个数字的中位数的算法,这是中位数的中位数算法的基础。Rex Kerr在this answer中提供了此算法-该算法非常依赖于其速度。
def medianUpTo5(five: Array[Double]): Double = {
def order2(a: Array[Double], i: Int, j: Int) = {
if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
}
def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
if (a(i)<a(k)) { order2(a,j,k); a(j) }
else { order2(a,i,l); a(i) }
}
if (five.length < 2) return five(0)
order2(five,0,1)
if (five.length < 4) return (
if (five.length==2 || five(2) < five(0)) five(0)
else if (five(2) > five(1)) five(1)
else five(2)
)
order2(five,2,3)
if (five.length < 5) pairs(five,0,1,2,3)
else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
else { order2(five,3,4); pairs(five,0,1,3,4) }
}
而且,接下来是中位数算法本身。基本上,它保证所选择的枢轴将大于至少30%并小于列表的其他30%,这足以保证先前算法的线性性。有关详细信息,请查看另一个答案中提供的维基百科链接。
def medianOfMedians(arr: Array[Double]): Double = {
val medians = arr grouped 5 map medianUpTo5 toArray;
if (medians.size <= 5) medianUpTo5 (medians)
else medianOfMedians(medians)
}
原地算法
这里是一个原地算法的版本。我使用了一个实现了原地分区的类,其中包含一个支持数组,以便对算法进行最小化的更改。
case class ArrayView(arr: Array[Double], from: Int, until: Int) {
def apply(n: Int) =
if (from + n < until) arr(from + n)
else throw new ArrayIndexOutOfBoundsException(n)
def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
var upper = until - 1
var lower = from
while (lower < upper) {
while (lower < until && p(arr(lower))) lower += 1
while (upper >= from && !p(arr(upper))) upper -= 1
if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
}
(copy(until = lower), copy(from = lower))
}
def size = until - from
def isEmpty = size <= 0
override def toString = arr mkString ("ArraySize(", ", ", ")")
}; object ArrayView {
def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}
@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
val a = choosePivot(arr)
val (s, b) = arr partitionInPlace (a >)
if (s.size == k) a
else if (s.isEmpty) {
val (s, b) = arr partitionInPlace (a ==)
if (s.size > k) a
else findKMedianInPlace(b, k - s.size)
} else if (s.size < k) findKMedianInPlace(b, k - s.size)
else findKMedianInPlace(s, k)
}
def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)
随机轴点,原地排序
我只为原地排序算法实现了随机轴点,因为中位数的中位数需要比我目前定义的ArrayView
类提供更多的支持。
def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
直方图算法(O(log(n))内存),不可变
关于流,如果只能遍历一次,那么无法做到比O(n)
更少的内存占用,除非你知道字符串长度(在这种情况下,它在我的书中就不再是流了)。
使用桶也有一些问题,但如果我们可以多次遍历它,那么我们就可以知道其大小、最大值和最小值,并从那里开始工作。例如:
def findMedianHistogram(s: Traversable[Double]) = {
def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
val buckets = new Array[Int](numberOfBuckets)
val max = s.max
val min = s.min
val increment = (max - min) / numberOfBuckets
val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)
def bucketIndex(d: Double) = indices indexWhere (d <=)
s foreach { d => buckets(bucketIndex(d)) += 1 }
val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)
val medianBucket = partialTotals indexWhere (medianIndex <)
val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)
def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket
val view = s.view filter insideMedianBucket
if (view forall (view.head ==)) view.head
else medianHistogram(view, newDiscarded, medianIndex)
}
medianHistogram(s, 0, (s.size - 1) / 2)
}
测试和基准
为了测试算法,我使用Scalacheck,并将每个算法的输出与带有排序的平凡实现的输出进行比较。当然,这假设排序版本是正确的。
我正在使用所有提供的轴点选择以及固定轴点选择(数组的一半,向下舍入)来对上述算法进行基准测试。每个算法都使用三种不同的输入数组大小进行测试,并针对每个大小进行三次测试。
以下是测试代码:
import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._
def test(algorithm: Array[Double] => Double,
reference: Array[Double] => Double): String = {
def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
val resultEqualsReference = forAll { (arr: Array[Double]) =>
arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
}
Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}
import java.lang.System.currentTimeMillis
def bench[A](n: Int)(body: => A): Long = {
val start = currentTimeMillis()
1 to n foreach { _ => body }
currentTimeMillis() - start
}
import scala.util.Random.nextDouble
def benchmark(algorithm: Array[Double] => Double,
arraySizes: List[Int]): List[Iterable[Long]] =
for (size <- arraySizes)
yield for (iteration <- 1 to 3)
yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))
def testAndBenchmark: String = {
val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
"Random Pivot" -> chooseRandomPivot,
"Median of Medians" -> medianOfMedians,
"Midpoint" -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
)
val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
"Random Pivot (in-place)" -> chooseRandomPivotInPlace,
"Midpoint (in-place)" -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
)
val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
yield name -> (findMedian(_: Array[Double])(pivotSelection))
val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms
val formattingString = "%%-%ds %%s" format (algorithms map (_._1.length) max)
val testResults = for ((name, algorithm) <- algorithms)
yield formattingString format (name, test(algorithm, sortingAlgorithm._2))
val arraySizes = List(100, 500, 1000)
def formatResults(results: List[Long]) = results map ("%8d" format _) mkString
val benchmarkResults: List[String] = for {
(name, algorithm) <- algorithms
results <- benchmark(algorithm, arraySizes).transpose
} yield formattingString format (name, formatResults(results))
val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))
"Tests" :: "*****" :: testResults :::
("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
结果
测试:
Tests
*****
Sorting OK, passed 100 tests.
Histogram OK, passed 100 tests.
Random Pivot OK, passed 100 tests.
Median of Medians OK, passed 100 tests.
Midpoint OK, passed 100 tests.
Random Pivot (in-place)OK, passed 100 tests.
Midpoint (in-place) OK, passed 100 tests.
基准测试:
Benchmark
*********
Algorithm 100 500 1000
Sorting 1038 6230 14034
Sorting 1037 6223 13777
Sorting 1039 6220 13785
Histogram 2918 11065 21590
Histogram 2596 11046 21486
Histogram 2592 11044 21606
Random Pivot 904 4330 8622
Random Pivot 902 4323 8815
Random Pivot 896 4348 8767
Median of Medians 3591 16857 33307
Median of Medians 3530 16872 33321
Median of Medians 3517 16793 33358
Midpoint 1003 4672 9236
Midpoint 1010 4755 9157
Midpoint 1017 4663 9166
Random Pivot (in-place) 392 1746 3430
Random Pivot (in-place) 386 1747 3424
Random Pivot (in-place) 386 1751 3431
Midpoint (in-place) 378 1735 3405
Midpoint (in-place) 377 1740 3408
Midpoint (in-place) 375 1736 3408
分析
除了排序版本,所有算法的结果都与平均线性时间复杂度兼容。
中位数算法可以保证在最坏情况下具有线性时间复杂度,但比随机枢轴要慢得多。
固定枢轴选择略逊于随机枢轴,但在非随机输入时可能表现更差。
原地版本大约快230%〜250%,但进一步的测试(未显示)似乎表明这种优势随数组大小增加而增加。
直方图算法让我非常惊讶。它展示了平均线性时间复杂度,并且比中位数算法快33%。然而,输入确实是随机的。最坏情况是二次的 - 我在调试代码时看到了一些例子。