如何在Scala中优化for推导式和循环?

130
所以Scala应该和Java一样快。我正在用Scala重新解决一些我最初用Java解决的欧拉计划问题,具体来说是问题5:“什么是能被1到20之间所有数字整除的最小正整数?”
这是我的Java解决方案,在我的机器上完成需要0.7秒:
public class P005_evenly_divisible implements Runnable{
    final int t = 20;

    public void run() {
        int i = 10;
        while(!isEvenlyDivisible(i, t)){
            i += 2;
        }
        System.out.println(i);
    }

    boolean isEvenlyDivisible(int a, int b){
        for (int i = 2; i <= b; i++) {
            if (a % i != 0) 
                return false;
        }
        return true;
    }

    public static void main(String[] args) {
        new P005_evenly_divisible().run();
    }
}

这是我将内容直接翻译成Scala的结果,需要103秒(比原文慢147倍!)

object P005_JavaStyle {
    val t:Int = 20;
    def run {
        var i = 10
        while(!isEvenlyDivisible(i,t))
            i += 2
        println(i)
    }
    def isEvenlyDivisible(a:Int, b:Int):Boolean = {
        for (i <- 2 to b)
            if (a % i != 0)
                return false
        return true
    }
    def main(args : Array[String]) {
        run
    }
}

最后,这是我尝试函数式编程的结果,用时39秒(比原来慢了55倍)

object P005 extends App{
    def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}
    def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
    println (find (2))
}

我在Windows 7 64位系统上使用Scala 2.9.0.1。如何提高性能?我是否做错了什么?或者Java只是更快吗?


2
你是使用Scala Shell编译还是解释执行? - ahmet alp balkan
2
你没有展示你是如何计时的。你尝试过只计时run方法吗? - Aaron Novstrup
2
@hammar - 是的,我刚刚用笔和纸的方式做了它:从高数开始写下每个数字的质数因子,然后划掉您已经拥有的更高数字的因子,这样您最终得到(522)(19)(33)(17)(22)()(7)(13)()*(11) = 232792560。 - Luigi Plinge
2
+1 这是我在 SO 上数周内看到的最有趣的问题(也有我看过的最好的答案之一)。 - Mia Clarke
@Matt,@Andrew,我通常在我的 Java 类中实现Runnable接口,这些类是用于运行的:这样做更符合概念上的含义,也可以轻松地从其他地方(例如Swing GUI)启动新线程。但是对于这个讨论,我应该把它排除在外,因为当我们使用"main"时,它不相关。 - Luigi Plinge
显示剩余5条评论
8个回答

112

这个特定情况的问题在于你从for表达式内部返回。这进而被转换为NonLocalReturnException的抛出,该异常在封闭方法中被捕获。优化器可以消除foreach,但目前无法消除throw / catch。而throw / catch是昂贵的。但是由于这样的嵌套返回在Scala程序中很少见,因此优化器尚未解决这个问题。正在进行改进优化器的工作,希望能很快解决此问题。


9
回归变成了例外,这真的很严重。我确定这已经有记录了,但它给人一种难以理解的神秘感。那真的只有这种方法吗? - skrebbel
10
如果返回操作发生在闭包内部,似乎这是最好的可用选项。从闭包外部返回当然会被直接翻译成字节码中的返回指令。 - Martin Odersky
1
我相信我忽略了什么,但为什么不改为在闭包内编译返回值以设置一个封闭的布尔标志和返回值,并在闭包调用返回后检查它呢? - Luke Hutteman
9
他的功能性算法为什么还是比原来慢55倍?看起来它不应该遭受如此可怕的性能问题。 - Elijah
7
现在,2014年,我再次进行了测试,我的性能如下:java -> 0.3秒;scala -> 3.6秒;scala优化 -> 3.5秒;scala函数式 -> 4秒;看起来比三年前好多了,但是……差距仍然太大。我们能期望更多的性能改进吗?换句话说,Martin,在理论上,还有可能进行优化的地方吗? - sasha.sochka
显示剩余2条评论

80
问题很可能是方法isEvenlyDivisible中使用了for推导式。将for替换为等效的while循环可消除与Java之间的性能差异。
与Java的for循环不同,Scala的for推导实际上是高阶方法的语法糖;在这种情况下,您正在调用Range对象上的foreach方法。Scala的for非常通用,但有时会导致性能下降。
您可能希望尝试Scala 2.9版本中的-optimize标志。观察到的性能可能取决于使用的特定JVM以及JIT优化器具有足够的“预热”时间来识别和优化热点。
最近在邮件列表上的讨论表明,Scala团队正在努力改进简单情况下的for性能:

以下是问题跟踪器中的问题: https://issues.scala-lang.org/browse/SI-4633

更新 5/28

  • 作为短期解决方案,ScalaCL 插件(alpha 版)将简单的 Scala 循环转换为等效的 while 循环。
  • 作为潜在的长期解决方案,EPFL 和 Stanford 的团队正在 合作开发一个项目,实现对于非常高性能的 "虚拟" Scala 的运行时编译。例如,多个惯用的函数式循环可以在运行时 融合为最佳的 JVM 字节码,或者到另一个目标,比如 GPU。该系统具有可扩展性,允许用户定义 DSL 和变换。请查看 公开出版物 和 Stanford 课程笔记。初步代码已经在 Github 上发布,计划在未来几个月内发布正式版本。

6
太棒了,我用while循环替换了for推导式,速度和Java版本一模一样(误差在1%以内)。感谢……我几乎对Scala失去了信心!现在只需要研究一个好的函数式算法... :) - Luigi Plinge
25
值得注意的是,尾递归函数和 while 循环一样快(因为它们都转换为非常相似或相同的字节码)。 - Rex Kerr
7
这也曾经困扰过我。由于速度极慢,我不得不将一个算法从使用集合函数转换为嵌套的while循环(第6级!)这是需要大力解决的问题,我认为;如果我需要良好(注意:不是极快)的性能,那么优雅的编程风格有何用处呢? - Raphael
7
那么,什么情况下适合使用 "for"? - OscarRyz
1
我尝试了ScalaCL,并且在我的电脑上将上面的函数版本降至1.89秒。这是超过20倍的速度提升!虽然不如尾递归好,但也差不多,并且更加简洁。 - Luigi Plinge
显示剩余2条评论

32
作为跟进,我尝试了 -optimize 标志,并将运行时间从103秒降至76秒,但仍然比Java或while循环慢107倍。
然后我看了“函数式”版本:
object P005 extends App{
  def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}
  def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
  println (find (2))
}

我正在尝试找出如何以简洁的方式消除“forall”。但是我失败了,只能提供以下内容:

object P005_V2 extends App {
  def isDivis(x:Int):Boolean = {
    var i = 1
    while(i <= 20) {
      if (x % i != 0) return false
      i += 1
    }
    return true
  }
  def find(n:Int):Int = if (isDivis(n)) n else find (n+2)
  println (find (2))
}

我的巧妙5行代码的解决方案变成了12行。然而,这个版本运行时间为0.71秒,与原始的Java版本运行速度相同,比使用“forall”(40.2秒)的上面的版本快56倍!(关于为什么它比Java更快,请参见下面的编辑)

显然,我的下一步是将上述内容翻译回Java,但是Java无法处理它,并在n约22000时抛出StackOverflowError。

然后,我想了一会儿,用更多的尾递归替换了“while”,这样可以节省几行代码,运行速度也同样快,但让我们面对现实,更难理解:

object P005_V3 extends App {
  def isDivis(x:Int, i:Int):Boolean = 
    if(i > 20) true
    else if(x % i != 0) false
    else isDivis(x, i+1)

  def find(n:Int):Int = if (isDivis(n, 2)) n else find (n+2)
  println (find (2))
}

因此,Scala的尾递归胜出了,但我感到惊讶的是,“for”循环(和“forall”方法)这么简单的东西实际上是有缺陷的,必须用冗长且不雅的“while”或者尾递归来替代。我尝试使用Scala的很大原因是它简洁的语法,但如果我的代码运行速度会慢100倍,那就没什么用了!

编辑:(已删除)

编辑编辑:2.5秒和0.7秒之间的运行时间差异完全是由于使用32位还是64位JVM引起的。从命令行运行Scala时,将使用JAVA_HOME设置的任何内容,而Java则始终使用64位(如果有可用的话)。IDE具有自己的设置。在这里可以找到一些测量结果:Scala execution times in Eclipse

1
isDivis方法可以写成:def isDivis(x: Int, i: Int): Boolean = if (i > 20) true else if (x % i != 0) false else isDivis(x, i+1)。请注意,在Scala中,if-else是一个表达式,它总是返回一个值。这里不需要return关键字。 - kiritsuku
3
您的最新版本(P005_V3)可以通过以下方式进行简化、更加简洁和清晰,例如编写: def isDivis(x: Int, i: Int): Boolean = (i > 20) || (x % i == 0) && isDivis(x, i+1) - Blaisorblade
@Blaisorblade 不行。这会破坏尾递归性,而尾递归性是将其转换为字节码中的 while 循环所必需的,进而使得执行速度更快。 - gzm0
4
我明白你的观点,但我的例子仍然是尾递归,因为&&和||使用短路求值,这一点已通过使用@tailrec进行确认: https://gist.github.com/Blaisorblade/5672562 - Blaisorblade

8
有关for推导式的答案是正确的,但这并不是全部。你应该注意到在isEvenlyDivisible中使用return是有代价的。在for内部使用return会强制Scala编译器生成一个非局部返回(即在其函数外返回)。
这是通过使用异常退出循环来完成的。如果您构建自己的控制抽象,例如:
def loop[T](times: Int, default: T)(body: ()=>T) : T = {
    var count = 0
    var result: T = default
    while(count < times) {
        result = body()
        count += 1
    }
    result
}

def foo() : Int= {
    loop(5, 0) {
        println("Hi")
        return 5
    }
}

foo()

这段代码只会打印一次“Hi”。
需要注意的是,在函数foo中使用return语句会退出foo,这是你所期望的。由于括号中的表达式是一个函数字面量,你可以在loop的签名中看到这一点,这迫使编译器生成一个非局部返回,也就是说,return语句强制你退出的是foo而不仅仅是body
在Java(即JVM)中实现这种行为的唯一方法是抛出异常。
回到isEvenlyDivisible:
def isEvenlyDivisible(a:Int, b:Int):Boolean = {
  for (i <- 2 to b) 
    if (a % i != 0) return false
  return true
}

if (a % i != 0) return false 是一个带有返回值的函数字面量,因此每次遇到返回语句时,运行时都需要抛出和捕获异常,这会导致相当大的垃圾回收开销。


7

我发现了一些加速 forall 方法的方法:

原始时间:41.3秒

def isDivis(x:Int) = (1 to 20) forall {x % _ == 0}

预先实例化范围,以避免每次创建新范围: 9.0秒

val r = (1 to 20)
def isDivis(x:Int) = r forall {x % _ == 0}

将Range转换为List:4.8秒

val rl = (1 to 20).toList
def isDivis(x:Int) = rl forall {x % _ == 0}

我尝试了几个其他的集合,但List是最快的(虽然仍然比避免使用Range和高阶函数慢7倍)。

虽然我对Scala还很陌生,但我猜编译器可以通过自动将方法中的Range文本替换为外部作用域的Range常量(如上所示)来实现快速且显著的性能提升。或者更好的方式是像Java中的字符串文本一样进行内部化。


: 数组与Range差不多,但有趣的是,增强一个新的forall方法(如下所示)在64位上执行速度快24%,在32位上快8%。当我通过将因子数量从20个减少到15个来减小计算大小时,差异消失了,因此可能是垃圾回收效应。无论原因是什么,在长时间的满负荷运行下都是显著的。

对于List的类似增强也使性能提高了约10%。

  val ra = (1 to 20).toArray
  def isDivis(x:Int) = ra forall2 {x % _ == 0}

  case class PimpedSeq[A](s: IndexedSeq[A]) {
    def forall2 (p: A => Boolean): Boolean = {      
      var i = 0
      while (i < s.length) {
        if (!p(s(i))) return false
        i += 1
      }
      true
    }    
  }  
  implicit def arrayToPimpedSeq[A](in: Array[A]): PimpedSeq[A] = PimpedSeq(in)  

3
我想对那些因为类似问题而失去对Scala的信心的人发表评论,这种问题在几乎所有函数式语言的性能中都会出现。如果你正在优化Haskell中的fold操作,你经常需要将其重写为递归尾调用优化循环,否则你将面临性能和内存问题。
我知道FP还没有被优化到我们不必考虑这样的事情的程度,但这绝不是Scala特有的问题。

2

已经讨论了与Scala相关的特定问题,但主要问题在于使用蛮力算法并不是很酷。考虑这个(比原始Java代码快得多):

def gcd(a: Int, b: Int): Int = {
    if (a == 0)
        b
    else
        gcd(b % a, a)
}
print (1 to 20 reduce ((a, b) => {
  a / gcd(a, b) * b
}))

1
这个问题比较了特定逻辑在不同编程语言中的性能。算法是否最优并不重要。 - smartnut007

1

尝试在解决方案中给出的一行代码Scala for Project Euler

虽然远离while循环,但所需时间至少比你的快.. :)


它与我的函数式版本非常相似。你可以将我的写成 def r(n:Int):Int = if ((1 to 20) forall {n % _ == 0}) n else r (n+2); r(2),这比Pavel的代码少了4个字符。 :) 不过我并不觉得我的代码有多好 - 当我发布这个问题时,我只写了大约30行Scala代码。 - Luigi Plinge

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接