在Scala中是否有一种通用的记忆化方法?

60
我想要进行记忆化处理:
def fib(n: Int) = if(n <= 1) 1 else fib(n-1) + fib(n-2)
println(fib(100)) // times out

我写了以下代码,令人惊讶的是它编译并且可以运行(我感到惊讶是因为fib在声明中引用了自身):

case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}

val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

println(fib(100))     // prints 100th fibonacci number instantly

但是当我尝试在def内部声明fib时,就会出现编译器错误:
def foo(n: Int) = {
  val fib: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fib(n-1) + fib(n-2) 
  }
  fib(n)
} 

以上代码无法编译:错误:前置引用扩展了值fib的定义 case n => fib(n-1) + fib(n-2)

为什么在def内部声明val fib会失败,但在class/object范围之外声明则可以成功?

为了澄清,为什么我想在def范围内声明递归记忆函数 - 这是我解决子集和问题的方法:

/**
   * Subset sum algorithm - can we achieve sum t using elements from s?
   *
   * @param s set of integers
   * @param t target
   * @return true iff there exists a subset of s that sums to t
   */
  def subsetSum(s: Seq[Int], t: Int): Boolean = {
    val max = s.scanLeft(0)((sum, i) => (sum + i) max sum)  //max(i) =  largest sum achievable from first i elements
    val min = s.scanLeft(0)((sum, i) => (sum + i) min sum)  //min(i) = smallest sum achievable from first i elements

    val dp: Memo[(Int, Int), Boolean] = Memo {         // dp(i,x) = can we achieve x using the first i elements?
      case (_, 0) => true        // 0 can always be achieved using empty set
      case (0, _) => false       // if empty set, non-zero cannot be achieved
      case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x)  // try with/without s(i-1)
      case _ => false            // outside range otherwise
    }

    dp(s.length, t)
  }

3
请查看我的博客文章,了解递归函数备忘录另一种变体的实现方法。 - michid
2
在我发布任何SO之前,我会先谷歌一下,你的博客文章是第一个结果 :) 我同意这是“正确”的方法 - 使用Y组合器。但是,我认为使用我的风格并利用lazy val看起来比每个函数都有2个定义(递归和Y组合)更干净。看看这个链接多么干净。 - pathikrit
我对您上述问题中一些语法的简明性感到困惑(特别是case类使用“extend(A => B)”)。我发布了一个关于此的问题:https://dev59.com/hmIk5IYBdhLWcg3wI7EF - chaotic3quilibrium
请谨慎使用此模式,因为它可能会带来与“Map”相关的并发问题:https://dev59.com/rFnUa4cB1Zd3GeqPbYyn#6807324 - lcn
问题主体和被接受的答案与此问题的标题无关。您能否更改标题? - user239558
4个回答

55

我发现了一种更好的使用Scala进行记忆化的方法:

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I) = getOrElseUpdate(key, f(key))
}

现在你可以按照以下方式编写斐波那契数列:
lazy val fib: Int => BigInt = memoize {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)
}

这里是一个有多个参数的例子(choose函数):

lazy val c: ((Int, Int)) => BigInt = memoize {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}

以下是子集和问题:

// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = {
  // f is (i, j) => Boolean i.e. can the first i elements of s add up to j
  lazy val f: ((Int, Int)) => Boolean = memoize {
    case (_, 0) => true        // 0 can always be achieved using empty list
    case (0, _) => false       // we can never achieve non-zero if we have empty list
    case (i, j) => 
      val k = i - 1            // try the kth element
      f(k, j - s(k)) || f(k, j)
  }
  f(s.length, t)
}

编辑:如下所讨论,这里是一个线程安全版本

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {self =>
  override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))
}

2
我认为这个(或者大多数基于mutable.Map的实现)不是线程安全的?但如果在单线程环境中使用,看起来语法很好。 - Gary Coady
我不确定可变的HashMap实现是否会导致崩溃和/或损坏数据,或者主要问题只是缺少更新;对于大多数用例来说,缺少更新可能是可以接受的。 - Gary Coady
2
我想知道在TrieMap上是否可能出现死锁。毕竟,在getOrElseUpdate方法内部,该映射被“递归”访问。 - VasiliNovikov
@VasyaNovikov:然后我们可以通过在 self.synchronized {getOrElseUpdate} 周围包装锁来使锁更加粗略。 - pathikrit
2
@pathikrit:我认为使用mutable.HashMap的self.synchronized版本没有任何问题。我的评论主要是对上面评论中关于TrieMap的讨论进行澄清,因为事实证明不能简单地将TrieMap替换为给定代码中的内容。 - Jeff Klukas
显示剩余6条评论

22

类/特征级别的val编译为方法和私有变量的组合,因此允许递归定义。

另一方面,局部val只是普通变量,因此不允许递归定义。

顺便说一下,即使你定义的def有效,它也无法达到你的预期。每次调用foo时,都会创建一个新的函数对象fib并具有自己的备份映射。如果您真的希望def成为公共接口,则应该执行以下操作:

private val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

def foo(n: Int) = {
  fib(n)
} 

“foo”和“fib”只是简化 - 在我的情况下,foo是子集和问题,而fib是递归记忆化输入集的函数,因此我不能简单地将我的记忆化函数提取到方法之外。您能解释一下“类级val编译为方法和私有变量的组合”的部分是什么意思吗?类和方法val之间还有哪些差异需要注意? - pathikrit
i) 有什么阻止你将其从方法之外提取的吗? ii) 当您在类/特质级别编写val x = N时,您得到的是def x = _xprivate val _x = N。您应该在任何Scala书中找到这个解释。我想不起来字段val和本地val之间的其他区别了。 - missingfaktor
9
在局部范围内,您可以使用以下解决方法:将fib定义为lazy val。然后,您应该能够在局部范围内进行递归调用。 - missingfaktor
如果它使用了可变状态和val,那么它是否意味着它不是线程安全的? - ses
@ses,除非该可变状态具有线程安全保证。 (您可以是可变的和线程安全的。只是...更加困难。) - missingfaktor
如果您能展示如何制作通用的n元函数,我们可以提供更多的赞成票。 - user48956

11

Scalaz已经有了一个解决方案,为什么不重复利用呢?

import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-2) + fib(n-1)
}

你可以阅读更多有关于Scalaz中的记忆化


1
可变的HashMap不是线程安全的。此外,为基本条件单独定义case语句似乎是不必要的特殊处理,相反,可以将Map加载初始值并传递给Memoizer。以下是Memoizer的签名,它接受一个memo(不可变Map)和公式,并返回一个递归函数。
Memoizer看起来像:
def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O

现在给定以下斐波那契公式:
def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2)

斐波那契数列与Memoizer结合可以定义为:
val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib)

其中,上下文无关的通用记忆化函数定义如下:

    def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = {
        var memo = map
        def recur(n: I): O = {
          if( memo contains n) {
            memo(n) 
          } else {
            val result = formula(recur, n)
            memo += (n -> result)
            result
          }
        }
        recur
      }

同样地,对于阶乘,一个公式是

def fac(f: Int => Int, n: Int): Int = n * f(n-1)

使用Memoizer的阶乘是:
val factorial = memoize( Map(0 -> 1, 1 -> 1), fac)

灵感:记忆函数,摘自Douglas Crockford的《Javascript语言精粹》第4章


为基本情况单独定义案例语句似乎是不必要的特殊处理。真的吗?实际上,斐波那契数列是简单基本情况的罕见例子。你会如何使用此方法解决背包问题(https://github.com/pathikrit/scalgos/blob/master/src/main/scala/com/github/pathikrit/scalgos/DynamicProgramming.scala#L103)? - pathikrit
在斐波那契数列或任何已知值的情况下,应该将其预加载到映射中。这使得公式函数更接近于其数学定义,我认为是这样的。如果公式需要比较(如case语句或if...else块),例如解决背包问题,使用case语句是完全可以的。 - Boolean

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接