为什么我的Scala尾递归比while循环更快?

35
以下是Cay Horstmann的《Scala快速入门》中练习4.9的两个解决方案:“编写一个函数lteqgt(values:Array [Int],v:Int),返回一个三元组,其中包含小于v的值的计数,等于v的值的计数和大于v的值的计数。” 其中一个使用尾递归,另一个使用while循环。我以为两者都会编译成类似的字节码,但while循环比尾递归慢近2倍。这让我觉得我的while方法写得很糟糕。
import scala.annotation.tailrec
import scala.util.Random
object PerformanceTest {

  def main(args: Array[String]): Unit = {
    val bigArray:Array[Int] = fillArray(new Array[Int](100000000))
    println(time(lteqgt(bigArray, 25)))
    println(time(lteqgt2(bigArray, 25)))
  }

  def time[T](block : => T):T = {
    val start = System.nanoTime : Double
    val result = block
    val end = System.nanoTime : Double
    println("Time = " + (end - start) / 1000000.0 + " millis")
    result
  }

  @tailrec def fillArray(a:Array[Int], pos:Int=0):Array[Int] = {
    if (pos == a.length)
      a
    else {
      a(pos) = Random.nextInt(50)
      fillArray(a, pos+1)
    }
  }

  @tailrec def lteqgt(values: Array[Int], v:Int, lt:Int=0, eq:Int=0, gt:Int=0, pos:Int=0):(Int, Int, Int) = {
    if (pos == values.length)
      (lt, eq, gt)
    else
      lteqgt(values, v, lt + (if (values(pos) < v) 1 else 0), eq + (if (values(pos) == v) 1 else 0), gt + (if (values(pos) > v) 1 else 0), pos+1) 
  }

  def lteqgt2(values:Array[Int], v:Int):(Int, Int, Int) = {
    var lt = 0
    var eq = 0
    var gt = 0
    var pos = 0
    val limit = values.length
    while (pos < limit) {
      if (values(pos) > v)
        gt += 1
      else if (values(pos) < v)
        lt += 1
      else
        eq += 1
      pos += 1
    }
    (lt, eq, gt)
  }
}

根据您的堆大小调整bigArray的大小。以下是一些示例输出:

Time = 245.110899 millis
(50004367,2003090,47992543)
Time = 465.836894 millis
(50004367,2003090,47992543)

为什么while方法比tailrec方法慢很多?表面上看,tailrec版本似乎处于劣势,因为它必须在每次迭代中始终执行3个“if”检查,而while版本由于else结构的存在通常只需要执行1或2个测试。(注意,反转我执行这两种方法的顺序不会影响结果)。

1
我经常对此感到好奇。答案肯定在JIT中。禁用JIT后重复基准测试会很有趣。 - Daniel C. Sobral
请查看 https://dev59.com/ZWUp5IYBdhLWcg3wo4yd#48143130 的结果,其中 while 循环比尾递归更快(使用 scala 2.12.x 和 scalameter 基准测试来尝试管理 JVM 不一致性)。 - Darren Weber
2个回答

37

测试结果(将数组大小缩小到20000000后)

在Java 1.6.22下,尾递归和while循环分别需要 151毫秒122毫秒

在Java 1.7.0下,需要 55毫秒101毫秒

所以,在Java 6中,while循环实际上更快;在Java 7中,两者的性能都有所提高,但尾递归版本已经超过了while循环。

解释

性能差异是由于在循环中,您会根据条件向总数添加1,而对于递归,您始终添加1或0。因此它们并不相等。与您的递归方法等效的while循环是:

  def lteqgt2(values:Array[Int], v:Int):(Int, Int, Int) = {
    var lt = 0
    var eq = 0
    var gt = 0
    var pos = 0
    val limit = values.length
    while (pos < limit) {
      gt += (if (values(pos) > v) 1 else 0)
      lt += (if (values(pos) < v) 1 else 0)
      eq += (if (values(pos) == v) 1 else 0)
      pos += 1
    }
    (lt, eq, gt)
  }

这样做与递归方法完全相同(不管是哪个Java版本),执行时间也相同。

讨论

我不是Java 7 VM(HotSpot)为什么能比第一个版本更好地优化此代码的专家,但我猜测是因为它每次都经过相同的代码路径(而不是在if / else if路径上分支),所以字节码可以更有效地内联。

但请记住,在Java 6中不是这种情况。为什么一个while循环比另一个更快是JVM内部问题。对于Scala程序员来说,从惯用尾递归产生的版本是最新JVM中更快的版本。

差异也可能发生在处理器级别。参见这个问题,它解释了如果代码包含不可预测的分支会使代码变慢。


1
好的,谢谢,我使用那个版本也得到了相同的性能结果。因此,只要我在两种情况下正确编写等效的主体,尾递归和while循环结构可能会编译成非常相似的字节码。if / else语句方面有一个有趣的效果。 - waifnstray

24
这两种结构并不相同。特别地,在第一种情况下,你不需要任何跳转(在x86上,你可以使用cmp和setle和add,而不必使用cmp和jb和(如果你不跳转)add。不跳转比跳转快,在几乎所有现代架构上都如此。
因此,如果你的代码像这样:
if (a < b) x += 1

你可以选择添加,也可以选择跳过。

x += (a < b)

(这只在C/C++中有意义,其中1 = true,0 = false),后者往往更快,因为它可以转换为更紧凑的汇编代码。 在Scala/Java中,你不能这样做,但你可以

x += if (a < b) 1 else 0

一款聪明的 JVM 应该能够识别它与 x += (a < b) 相同,这有一个无需跳转的机器码转换,通常比跳转更快。更聪明的 JVM 应该能够识别到

if (a < b) x += 1

又一次相同了(因为添加零不会改变任何东西)。

C/C++编译器通常执行此类优化。不能应用任何这些优化并不是JIT编译器的优点;很明显,自1.7版本以来它已经可以进行部分优化(即它无法认识到加零与有条件地加一相同,但至少将x += if (a<b) 1 else 0转换成了快速机器代码)。

现在,所有这些都与尾递归或while循环本身无关。使用尾递归更自然的方式是编写if (a < b) 1 else 0形式,但您可以选择任何一种方式;而对于while循环,您也可以选择任何一种方式。只是碰巧您为尾递归选择了一种形式,为while循环选择了另一种形式,使其看起来就像递归与循环的区别一样,实际上这两种形式是不同的条件语句的写法。


很抱歉,你的回答细节超出了我的理解范围,但听起来似乎结论是尾递归应该作为一种编程风格优先于while循环(如果编译器支持),并且在Scala中,尾递归可能(未来或现在)比while循环运行更快。这正确吗? - waifnstray
@waifnstray - 不,那不是重点。让我编辑一下以便更清晰明了。 - Rex Kerr
明白了,谢谢。我误解了你所指的两个结构。 - waifnstray
无跳转编程在游戏开发中尤为常见。另外,如果你想知道的话,将增量从1改为10来验证这个答案。 - Seth

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接