为什么Scala的尾递归比Java的慢?

6

使用尾递归的Scala代码进行简单加法

def add(list : List[Int],sum:Int):Int = {
    //Thread.dumpStack()
    if (list.isEmpty) {
        sum
    } else {
        val headVal  = list.head
        add(list.tail, sum + headVal)
    }
}

以下是递归模式下的Java代码,用于加法运算。
public static int add(List<Integer> list, Integer sum) {
    // Thread.dumpStack();
    if (list.isEmpty()) {
        return sum;
    } else {
        int headVal = list.remove(0);
        return add(list, sum + headVal);
    }
}

Java版本运行速度至少快10倍。 运行了1000个条目的测试。在 System.nanoTime() API之前和之后使用API来测量时间。

Scala版本为2.10,Java版本为Java 7。两个版本使用相同的JVM属性。


1
我已经了解到性能测量对于像我这样的有些经验的程序员来说是有些棘手和充满陷阱的。不应低估JIT。你能分享一下你是如何测量性能的细节吗? - Christian Hujer
1
你可能想要进行以下操作:1. 检查字节码,2. 使用“-XX:+PrintCompilation”等参数运行JVM并查看情况;也许JIT能够优化Java生成的字节码,但不能优化Scala生成的字节码。 - fge
如果我取消注释Thread.dumpStack,我可以在Scala堆栈帧中只看到add方法被添加了一次。而在Java中,add方法会逐步添加到堆栈帧中。Java甚至在10K次迭代后会抛出StackOverflow异常。但是对于1K次迭代,Java的运行速度比Scala快10倍。 - RockSolid
5
你只运行了一次对1000个元素的列表求和的程序?这样做没有什么值得一提的测量结果。至少要在开始测量之前运行10000次求和操作;然后再运行10000次并进行测量,最后取平均值。 - Marko Topolnik
你可能应该使用JMH进行微基准测试,它可以处理JIT预热和时钟粒度等问题。 - the8472
4个回答

6
首先,你展示的Scala方法add不在上下文(类)中。如果你在一个类中有这个方法,尾递归优化将无法应用,因为该方法既不是final也不是private。如果添加@tailrec,编译将失败。如果我使用10000运行它,会导致堆栈溢出。
至于Java版本:Java版本使用了一个可变的List:头/尾分解改变了底层列表。因此,在求和之后,您不能再使用该列表,因为它为空。
此外,Scala中的List与Java List完全不同;Scala列表用于头/尾分解。据我所知,java.util.List没有tail方法,Java编译器也不会应用tailrec优化,因此比较是“不公平”的。
无论如何,我已经在不同的场景下运行了一些基于JMH的测试。
你真正可以比较的只有“Scala while”和“Java for”两种情况。它们既不使用OOP也不使用函数式编程,只是命令式的。
五种不同Scala场景的结果
(请向右滚动,最后一列有一个小描述)
Benchmark                   Mode   Samples         Mean   Mean error    Units
a.b.s.b.Benchmark5a.run    thrpt        10      238,515        7,769   ops/ms Like in the question, but with @tailrec
a.b.s.b.Benchmark5b.run    thrpt        10      130,202        2,232   ops/ms Using List.sum
a.b.s.b.Benchmark5c.run    thrpt        10     2756,132       29,920   ops/ms while (no List, vars, imperative)
a.b.s.b.Benchmark5d.run    thrpt        10      237,286        2,203   ops/ms tailrec version with pattern matching
a.b.s.b.Benchmark5e.run    thrpt        10      214,719        2,483   ops/ms Like in the question (= no tailrec opt)
  • 5a和5e与问题中的相似; 5a使用@tailrec
  • 5b:List.sum:非常慢。
  • 5c:不是很惊讶,命令式版本是最快的。
  • 5d使用模式匹配而不是if(那将是“我的风格”),我正在添加这个的来源:
package app.benchmark.scala.benchmark5

import scala.annotation._
import org.openjdk.jmh.annotations.GenerateMicroBenchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import org.openjdk.jmh.runner.Runner
import org.openjdk.jmh.runner.RunnerException
import org.openjdk.jmh.runner.options.Options
import org.openjdk.jmh.runner.options.OptionsBuilder

@State(Scope.Benchmark)
object BenchmarkState5d {
  val list = List.range(1, 1000)
}

class Benchmark5d {
  private def add(list : List[Int]): Int = {
    @tailrec
    def add(list : List[Int], sum: Int): Int = {
      list match {
        case Nil =>
          sum
        case h :: t =>
          add(t, h + sum)
      }
    }
    add(list, 0)
  }

  @GenerateMicroBenchmark
  def run() = {
    add(BenchmarkState5d.list)
  }
}

三种Java场景

Benchmark                   Mode   Samples         Mean   Mean error    Units
a.b.j.b.Benchmark5a.run    thrpt        10       40,437        0,532   ops/ms mutable (rebuilds the list in each iteration)
a.b.j.b.Benchmark5b.run    thrpt        10        0,450        0,008   ops/ms subList
a.b.j.b.Benchmark5c.run    thrpt        10     2735,951       29,177   ops/ms for

如果你真的想比较函数式编程风格(即不可变性,尾递归,头/尾分解)的含义,那么Java版本要慢五倍。
正如Marko Topolnik在评论中指出的:
subList不会复制尾部,但如果应用于LinkedList,则会执行类似的坏操作:它对原始列表进行包装,并使用偏移量来适应语义。结果是O(n)递归算法变为了O(n2),就像尾部反复复制一样。此外,包装器增多,最终会导致列表被包裹一千次。绝对不能与头/尾列表相媲美。
public class Benchmark5a {
    public static int add(List<Integer> list, Integer sum) {
        if (list.isEmpty()) {
            return sum;
        } else {
            int headVal = list.remove(0);
            return add(list, sum + headVal);
        }
    }

    @GenerateMicroBenchmark
    public long run() {
        final List<Integer> list = new LinkedList<Integer>();
        for(int i = 0; i < 1000; i++) {
            list.add(i);
        }
        return add(list, 0);
    }

    public static void main(String[] args) {
        System.out.println(new Benchmark5a().run());
    }
}

@State(Scope.Benchmark)
class BenchmarkState5b {
    public final static List<Integer> list = new LinkedList<Integer>();

    static {
        for(int i = 0; i < 1000; i++) {
            list.add(i);
        }
    }
}

public class Benchmark5b {
    public static int add(List<Integer> list, int sum) {
        if (list.isEmpty()) {
            return sum;
        } else {
            int headVal = list.get(0);
            return add(list.subList(1, list.size()), sum + headVal);
        }
    }

    @GenerateMicroBenchmark
    public long run() {
        return add(BenchmarkState5b.list, 0);
    }

    public static void main(String[] args) {
        System.out.println(new Benchmark5b().run());
    }
}

Scala详细结果

(所有结果仅显示最后一个场景和总体结果)

[...]

# VM invoker: /home/oracle-jdk-1.8-8u40/data/oracle-jdk-1.8.0_40/jre/bin/java
# VM options: <none>
# Fork: 1 of 1
# Warmup: 3 iterations, 1 s each
# Measurement: 10 iterations, 1 s each
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Throughput, ops/time
# Benchmark: app.benchmark.scala.benchmark5.Benchmark5e.run
# Warmup Iteration   1: 166,153 ops/ms
# Warmup Iteration   2: 215,242 ops/ms
# Warmup Iteration   3: 216,632 ops/ms
Iteration   1: 215,526 ops/ms
Iteration   2: 213,720 ops/ms
Iteration   3: 213,967 ops/ms
Iteration   4: 215,468 ops/ms
Iteration   5: 216,247 ops/ms
Iteration   6: 217,514 ops/ms
Iteration   7: 215,503 ops/ms
Iteration   8: 211,969 ops/ms
Iteration   9: 212,989 ops/ms
Iteration  10: 214,291 ops/ms

Result : 214,719 ±(99.9%) 2,483 ops/ms
  Statistics: (min, avg, max) = (211,969, 214,719, 217,514), stdev = 1,642
  Confidence interval (99.9%): [212,236, 217,202]


Benchmark                   Mode   Samples         Mean   Mean error    Units
a.b.s.b.Benchmark5a.run    thrpt        10      238,515        7,769   ops/ms
a.b.s.b.Benchmark5b.run    thrpt        10      130,202        2,232   ops/ms
a.b.s.b.Benchmark5c.run    thrpt        10     2756,132       29,920   ops/ms
a.b.s.b.Benchmark5d.run    thrpt        10      237,286        2,203   ops/ms
a.b.s.b.Benchmark5e.run    thrpt        10      214,719        2,483   ops/ms

Java详细结果

# VM invoker: /home/oracle-jdk-1.8-8u40/data/oracle-jdk-1.8.0_40/jre/bin/java
# VM options: <none>
# Fork: 1 of 1
# Warmup: 3 iterations, 1 s each
# Measurement: 10 iterations, 1 s each
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Throughput, ops/time
# Benchmark: app.benchmark.java.benchmark5.Benchmark5c.run
# Warmup Iteration   1: 2777,495 ops/ms
# Warmup Iteration   2: 2888,040 ops/ms
# Warmup Iteration   3: 2692,851 ops/ms
Iteration   1: 2737,169 ops/ms
Iteration   2: 2745,368 ops/ms
Iteration   3: 2754,105 ops/ms
Iteration   4: 2706,131 ops/ms
Iteration   5: 2721,593 ops/ms
Iteration   6: 2769,261 ops/ms
Iteration   7: 2734,461 ops/ms
Iteration   8: 2741,494 ops/ms
Iteration   9: 2740,012 ops/ms
Iteration  10: 2709,915 ops/ms

Result : 2735,951 ±(99.9%) 29,177 ops/ms
  Statistics: (min, avg, max) = (2706,131, 2735,951, 2769,261), stdev = 19,299
  Confidence interval (99.9%): [2706,774, 2765,128]


Benchmark                   Mode   Samples         Mean   Mean error    Units
a.b.j.b.Benchmark5a.run    thrpt        10       40,437        0,532   ops/ms
a.b.j.b.Benchmark5b.run    thrpt        10        0,450        0,008   ops/ms
a.b.j.b.Benchmark5c.run    thrpt        10     2735,951       29,177   ops/ms

更新:添加了另一个使用 ArrayList 的 Java 场景5d。
Benchmark                   Mode   Samples         Mean   Mean error    Units
a.b.j.b.Benchmark5a.run    thrpt        10       34,931        0,504   ops/ms
a.b.j.b.Benchmark5b.run    thrpt        10        0,430        0,005   ops/ms
a.b.j.b.Benchmark5c.run    thrpt        10     2610,085        9,664   ops/ms
a.b.j.b.Benchmark5d.run    thrpt        10       56,693        1,218   ops/ms

subList 不会复制尾部,但是当应用于 LinkedList 时,它会做一些类似糟糕的事情:包装原始列表并使用偏移量来适应语义。结果是 O(n) 递归算法变成了 O(n2)——就像尾部被重复复制一样。此外,包装器会不断增加,最终你会得到一个被包装了一千次的列表。绝对不能与头/尾列表相比。公平的比较方式是构建一个小型自定义链表实现。仅仅替换 ArrayList 也会显示出显著的改进。 - Marko Topolnik
@MarkoTopolnik 感谢您仔细查看结果。我已经在答案中添加了您对subList的评论。使用ArrayList会更好一些(请参见答案中的更新)。我同意“比较”不公平,并希望可以清楚地在答案中阅读到这一点。然而,这些结果是问题的答案。无论如何,即使使用优化的数据结构,仍然没有尾递归优化,那么堆栈仍将是Java的限制。但是,如果速度真的很重要,那么没有什么比命令式的方法更好了。 - Beryllium
通常我建议不要将尾递归看作是一种优化,而是作为语言的一个基本特性,它为您提供了另一种编写普通迭代的语法方式。最终,语义才是最重要的。但请注意,完全的尾递归消除是完全不同的东西,无法用任何其他习惯用语来复制。基本上,这是一种受控制的goto形式。 - Marko Topolnik

3
您无法从这样一个短暂的实验中得出任何有意义的性能结论。您可能遇到了垃圾回收,可能有其他进程占用了您的CPU,所有类可能没有加载完成,而JVM可能正在优化您的代码,而您的测试正在运行。任何这些情况都会对测试结果产生不利影响。
处理1000个元素将非常快。您应该让您的实验时间足够长,使得时间测量的不准确性或外部影响的效果更小。尝试使用一百万个元素,如果仍然只需要几秒钟,请尝试使用1000万个元素。
考虑在JVM启动时运行测试几次,以“预热”JVM。您希望确保已加载了任何惰性加载的类,并且在开始测试之前,JVM执行的任何实时优化都已完成。
使用更长的元素列表再重复实验几次。然后丢弃最快的结果和最慢的结果,并平均剩余结果。

需要消除的主要影响是在代码 JIT 编译之前进行类初始化和测量。顺便说一下,nanoTime 是唯一正确的选择。 - Marko Topolnik
谢谢,我根据你的建议添加了一些关于预热JVM的理由。我还删除了有关nanoTime的部分。我同意这是正确的函数调用。 - bhspencer
@MarkoTopolnik 请参阅http://shipilev.net/blog/2014/nanotrusting-nanotime/。`nanoTime`是标准库中唯一的选择,但仍需小心谨慎。 - Alexey Romanov
@AlexeyRomanov 一篇不错的文章,没错。但我没有看到你的观点:你有没有想到一些替代库可以改进nanoTime?而且在进行微基准测试时,要小心的远不止nanoTime这一个方面。总的来说,我只看到了你评论中的真理。 - Marko Topolnik

0

我尝试了你的代码,Scala版本大约快了5倍。它只需要不到4秒钟,而Java版本需要近20秒。

Java:

import java.util.List;
import java.util.LinkedList;

public class ListTest {
  public static int add(List<Integer> list, Integer sum) {
    if (list.isEmpty()) {
        return sum;
    } else {
        int headVal = list.remove(0);
        return add(list, sum + headVal);
    }
}

  public static void main(String[] args) {
    List<Integer> list = new LinkedList<>();
    int sum = 0;

    long start = System.nanoTime();
    for(int j = 0; j < 1000000; j++) {
      list.clear();
      for(int i = 1; i <= 1000; i++) list.add(i);
      sum = add(list, 0);
    }
    long end = System.nanoTime();
    System.out.println("time = " + ((end - start)/1e6) + "ms");
    System.out.println("sum = " + sum);
  }
}

Scala:

object ListTest {
  def add(list : List[Int],sum:Int): Int = {
    if (list.isEmpty) {
        sum
    } else {
        val headVal  = list.head
        add(list.tail, sum + headVal)
    }
  }

  def 

main(args: Array[String]) {
    val list = List.range(1, 1001)
    var sum = 0

    val start = System.nanoTime
    for(i <- 1 to 1000000) sum = add(list, 0);
    val end = System.nanoTime

    println("time = " + ((end - start)/1e6) + "ms")
    println("sum = " + sum)
  }
}

为什么你的Java版本会在测量循环内将所有元素添加到列表中,但Scala不会? - dhg
@dhg 因为 Java 版本使用了 remove,并且在求和后列表为空。 - Beryllium
你在测量中包含了列表构建。自然而然,Java的速度会变慢。你也没有从测量中排除预热。总的来说,结果展示了一些相当无关紧要和不太有趣的东西。 - Marko Topolnik

0
给定的 Scala 示例可能会得到改进:使用 Scala 中的尾递归函数,通常需要使用 @tailrec 注释来提供编译器提示。有关更多信息,请参见此处。 [注:由于拼写错误已更新]

3
你的意思是说“…在Scala中”,是吗?据我所知,@tailrec注解来自于Scala而不是Java,不是吗? - Christian Hujer
你说得完全正确。已经修复了。感谢你的提醒! - Timothy Perrigo
2
注释不是编译器提示,而是检查:您假设已注释的函数是尾递归的。如果不是,编译器将会报错。 - Jens Schauder

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接