快速的函数式归并排序

8

这是我在Scala中实现归并排序的代码:

object FuncSort {
  def merge(l: Stream[Int], r: Stream[Int]) : Stream[Int] = {
    (l, r) match {
      case (h #:: t, Empty) => l
      case (Empty, h #:: t) => r
      case (x #:: xs, y #:: ys) => if(x < y ) x #:: merge(xs, r) else y #:: merge(l, ys)
    }
  }

  def sort(xs: Stream[Int]) : Stream[Int] = {
    if(xs.length == 1) xs
    else {
      val m = xs.length / 2
      val (l, r) = xs.splitAt(m)
      merge(sort(l), sort(r))
    }
  }
}

它的功能正确,渐近看起来也很好,但比这里的Java实现http://algs4.cs.princeton.edu/22mergesort/Merge.java.html慢得多(大约10倍),并且使用了大量内存。是否有更快的合并排序实现是函数式的?显然,可以逐行移植Java版本,但这不是我要找的。

更新:我已将Stream更改为List,将#::更改为::,排序例程变得更快,只比Java版本慢三到四倍。但我不明白为什么不会因堆栈溢出而崩溃?merge不是尾递归,所有参数都是严格求值的...这怎么可能?


你确定它不会溢出堆栈吗?使用Iterator.continually(Random.nextInt).take(N).toList生成一个任意大的未排序列表。 - Aaron Novstrup
@AaronNovstrup 看起来它确实这样做。 - synapse
3个回答

3
您提出了多个问题。我会按照逻辑顺序回答它们:

Stream 版本没有堆栈溢出

您没有真正问这个问题,但它引发了一些有趣的观察。

在 Stream 版本中,您在 merge 函数内使用了 #:: merge(...)。通常,这将是一个递归调用,并可能导致对于足够大的输入数据而言堆栈溢出。但在这种情况下不会发生。运算符 #::(a,b)class ConsWrapper[A] 中实现(有一个隐式转换),并且是 cons.apply[A](hd: A, tl: ⇒ Stream[A]): Cons[A] 的同义词。正如您所看到的,第二个参数是按名称调用,这意味着它是惰性求值的。

这意味着 merge 返回一个新创建的类型为 cons 的对象,该对象最终将再次调用 merge。换句话说:递归不会发生在堆栈上,而是在堆上。通常,您有很多堆。

使用堆进行递归是一种处理非常深的递归的好技术。但它比使用堆栈慢得多。因此,您为递归深度而牺牲了速度。这就是使用 Stream 如此缓慢的主要原因。

第二个原因是,为了获取 Stream 的长度,Scala 必须将整个 Stream 实体化。但在排序 Stream 期间,它无论如何都必须实体化每个元素,因此这不会造成太大的影响。

List 版本没有堆栈溢出

当您将 Stream 更改为 List 时,确实使用堆栈进行递归。现在可能会发生堆栈溢出。但是,在排序中,您通常具有 log(size) 的递归深度,通常是以 2 为底的对数。因此,要对 40 亿个输入项进行排序,您需要大约 32 个堆栈帧。使用默认堆栈大小至少为 320k(在 Windows 上,其他系统具有较大的默认值),这样可以留下足够的递归空间,因此可以对许多输入数据进行排序。

更快的函数实现

这取决于 :-)

您应该使用堆栈,而不是堆进行递归。并且您应该根据输入数据决定您的策略:

  1. 对于小数据块,可以使用一些简单直接的算法就地排序。算法复杂度不会让你感到棘手,并且你可以通过将所有数据存储在缓存中获得更好的性能。当然,你仍然可以针对特定大小手写排序网络
  2. 如果你有数字输入数据,可以使用基数排序并将工作交给你处理器或GPU上的向量单元(更复杂的算法可以在GPU Gems中找到)。
  3. 对于中等大小的数据块,可以使用分而治之的策略将数据分割为多个线程(仅当你有多个内核时!)
  4. 对于巨大的数据块,使用归并排序并将其分成适合内存的块。如果需要,可以将这些块分布在网络上并在内存中进行排序。

不要使用swap操作,使用你的缓存。如果可以,请使用可变数据结构并进行原地排序。我认为,函数式和快速排序之间的结合效果不是很好。要使排序真正快速,你将不得不使用有状态的操作(例如,在可变数组上进行就地归并排序)。

我通常在所有程序中尝试这样做:尽可能地使用纯函数式风格,但对于可行的小部分使用有状态操作(例如,因为它具有更好的性能或代码必须处理大量状态,如果我使用var而不是val,则变得更加易读)。


2
这里有几点需要注意。首先,你没有正确处理初始流为空的情况。你可以通过修改sort内部的初始检查来解决这个问题,改为if(xs.length <= 1) xs。其次,流可能具有无法计算的长度(例如Stream.from(1)),这在尝试计算其一半长度时会出现问题(潜在的无限长度)。你可能需要考虑使用hasDefiniteSize或类似的方法进行检查(尽管使用不当可能会过滤掉一些本应该可计算的流)。最后,这个函数被定义为在流上操作可能会导致速度变慢。我尝试对你的流版本的归并排序和一个用于处理列表的版本进行了大量运行时间的比较,结果表明,列表版本大约快了3倍(尽管只是单次运行)。这表明,在这种方式下,流与列表或其他序列类型相比效率更低(使用向量可能仍然更快,或者像Java解决方案中引用的那样使用数组)。话虽如此,我对时间和效率不是很熟悉,因此其他人可能能够给出更具有专业知识的答复。

目前我不关心无限流,显然它们应该使用自底向上版本的算法进行处理。 - synapse
好的 - 虽然我一时半会儿也不确定这样一个自下而上的版本会是什么样子(我很想看看)。 - Shadowlands
1
将数字列表拆分为单例列表,然后合并它们。 - synapse
@synapse 你可能不在意无限流,但是你的代码没有防止它们被传递,一旦遇到就会崩溃。当使用有限流时,意外创建无限流是非常容易的。 - itsbruce

1
你的实现是自顶向下的归并排序。我发现自底向上的归并排序更快,并且与List.sorted(对于我的测试用例,随机大小的随机数字列表)相当。
def bottomUpMergeSort[A](la: List[A])(implicit ord: Ordering[A]): List[A] = {
  val l = la.length

  @scala.annotation.tailrec
  def merge(l: List[A], r: List[A], acc: List[A] = Nil): List[A] = (l, r) match {
    case (Nil, Nil)           => acc
    case (Nil, h :: t)        => merge(Nil, t, h :: acc)
    case (h :: t, Nil)        => merge(t, Nil, h :: acc)
    case (lh :: lt, rh :: rt) =>
      if(ord.lt(lh, rh)) merge(lt, r, lh :: acc)
      else               merge(l, rt, rh :: acc)
  }

  @scala.annotation.tailrec
  def process(la: List[A], h: Int, acc: List[A] = Nil): List[A] = {
    if(la == Nil) acc.reverse
    else {
      val (l1, r1) = la.splitAt(h)
      val (l2, r2) = r1.splitAt(h)

      process(r2, h, merge(l1, l2, acc))
    }
  }

  @scala.annotation.tailrec
  def run(la: List[A], h: Int): List[A] =
    if(h >= l) la
    else       run(process(la, h), h * 2)

  run(la, 1)
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接