在Scala中,用什么类型来存储一个可变的内存数据表?

20
每次调用函数时,如果对于给定的参数值它的结果尚未被记忆化,则我希望将结果放入内存表中。一个列用于存储结果,其他列用于存储参数值。
如何最好地实现这个功能?参数具有多种类型,包括一些枚举。
在C#中,我通常会使用 DataTable。在Scala中是否有相当的东西?

2
如果你在网上搜索“Scala函数记忆化”,你会发现有几篇关于这个主题的文章。 - Randall Schulz
5个回答

27
你可以使用一个mutable.Map[TupleN[A1, A2, ..., AN], R],以下定义(基于michid's blog中的记忆化代码)允许你轻松地对具有多个参数的函数进行记忆化。例如:
import Memoize._

def reallySlowFn(i: Int, s: String): Int = {
   Thread.sleep(3000)
   i + s.length
}

val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly

定义:

/**
 * A memoized unary function.
 *
 * @param f A unary function to memoize
 * @param [T] the argument type
 * @param [R] the return type
 */
class Memoize1[-T, +R](f: T => R) extends (T => R) {
   import scala.collection.mutable
   // map that stores (argument, result) pairs
   private[this] val vals = mutable.Map.empty[T, R]

   // Given an argument x, 
   //   If vals contains x return vals(x).
   //   Otherwise, update vals so that vals(x) == f(x) and return f(x).
   def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}

object Memoize {
   /**
    * Memoize a unary (single-argument) function.
    *
    * @param f the unary function to memoize
    */
   def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)

   /**
    * Memoize a binary (two-argument) function.
    * 
    * @param f the binary function to memoize
    * 
    * This works by turning a function that takes two arguments of type
    * T1 and T2 into a function that takes a single argument of type 
    * (T1, T2), memoizing that "tupled" function, then "untupling" the
    * memoized function.
    */
   def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) = 
      Function.untupled(memoize(f.tupled))

   /**
    * Memoize a ternary (three-argument) function.
    *
    * @param f the ternary function to memoize
    */
   def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
      Function.untupled(memoize(f.tupled))

   // ... more memoize methods for higher-arity functions ...

   /**
    * Fixed-point combinator (for memoizing recursive functions).
    */
   def Y[T, R](f: (T => R) => T => R): (T => R) = {
      lazy val yf: (T => R) = memoize(f(yf)(_))
      yf
   }
}

固定点组合子(Memoize.Y)使得可以对递归函数进行记忆化:
val fib: BigInt => BigInt = {                         
   def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
      if (n == 0) 1 
      else if (n == 1) 1 
      else (f(n-1) + f(n-2))                           
   }                                                     
   Memoize.Y(fibRec)
}

[1] WeakHashMap不适合用作缓存。请参见http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html这个相关问题


请注意,上述实现不是线程安全的,因此如果您需要从多个线程缓存某些计算,则可能会出现问题。为了将其更改为线程安全,请执行以下操作: private[this] val vals = new HashMap[T, R] with SynchronizedMap[T, R] - Grega Kešpret
1
有另一种递归函数记忆化的方法:https://dev59.com/U18f5IYBdhLWcg3wB-sU#25129872,它不需要使用Y组合器或者制定一个非递归形式,这对于具有多个参数的递归函数可能会很困难。实际上,这两种方法都依赖于Scala对函数递归的支持,即在使用Y组合器时,“yf”调用“yf”,而在链接的wrick变体中,记忆化函数将调用自身。 - lcn

10

anovstrup提出的使用可变Map的版本基本上与C#中的相同,因此易于使用。

但是,如果您希望还可以使用更多的函数式风格。它使用不可变映射(Maps),作为一种累加器。在示例中,将元组(而不是Int)用作键与可变情况完全相同。

def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1

def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
   case Some(f) => (f, m)
   case None => val (f_1,m1) = fibM(n-1,m)
                val (f_2,m2) = fibM(n-2,m1)
                val f = f_1+f_2
                (f, m2 + (n -> f))   
}

当然,这有点复杂,但这是一个很有用的技巧(请注意,上面的代码旨在清晰易懂,而不是速度快)。


3
作为这个主题的新手,我无法完全理解所给出的例子(但仍然要感谢)。恭敬地说,我会针对那些与我水平相同且有相同问题的人提供自己的解决方案。我认为我的代码对于任何只具备非常基础的 Scala 知识的人来说都是清晰易懂的。


def MyFunction(dt : DateTime, param : Int) : Double
{
  val argsTuple = (dt, param)
  if(Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}
def MyRawFunction(dt : DateTime, param : Int) : Double { 1.0 // 这里进行了繁重的计算/查询 }
def Memoize(dt : DateTime, param : Int, result : Double) : Double { Memo += (dt, param) -> result result }
val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double]

运行得非常完美。如果我错过了什么,请指出。


1
我在我的解决方案中添加了一些注释,希望能为您澄清它。我所提出的方法的优点是它允许您记忆化任何函数(好吧,有一些注意事项,但很多函数)。有点像您在相关问题中发布的记忆化关键字。 - Aaron Novstrup
2
可能仍然神秘的一个方面是固定点组合器 - 对此,我鼓励您阅读michid的博客,喝很多咖啡,并可能与一些函数式编程文本建立友好关系。好消息是,只有在记忆递归函数时才需要它。 - Aaron Novstrup

1
使用可变映射表进行记忆化时,需要注意这将导致典型的并发问题,例如在写入尚未完成时进行获取。然而,线程安全的记忆化尝试建议这样做,如果没有或几乎没有价值。

以下线程安全代码创建了一个记忆化的 斐波那契 函数,初始化了几个线程(从 'a' 到 'd' 命名),并对其进行调用。多次尝试代码(在REPL中),可以很容易地看到 f(2) set 被打印了不止一次。这意味着线程A已经启动了 f(2) 的计算,但线程B完全不知道,并开始自己的复制计算。这种无知在缓存构建阶段是如此普遍,因为所有线程都看不到已建立的子解决方案,都会进入 else 子句。

object ScalaMemoizationMultithread {

  // do not use case class as there is a mutable member here
  class Memo[-T, +R](f: T => R) extends (T => R) {
    // don't even know what would happen if immutable.Map used in a multithreading context
    private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]
    def apply(x: T): R =
      // no synchronized needed as there is no removal during memoization
      if (cache containsKey x) {
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
        cache.get(x)
      } else {
        val res = f(x)
        Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
        cache.putIfAbsent(x, res) // atomic
        res
      }
  }

  object Memo {
    def apply[T, R](f: T => R): T => R = new Memo(f)

    def Y[T, R](F: (T => R) => T => R): T => R = {
      lazy val yf: T => R = Memo(F(yf)(_))
      yf
    }
  }

  val fibonacci: Int => BigInt = {
    def fiboF(f: Int => BigInt)(n: Int): BigInt = {
      if (n <= 0) 1
      else if (n == 1) 1
      else f(n - 1) + f(n - 2)
    }

    Memo.Y(fiboF)
  }

  def main(args: Array[String]) = {
    ('a' to 'd').foreach(ch =>
      new Thread(new Runnable() {
        def run() {
          import scala.util.Random
          val rand = new Random
          (1 to 2).foreach(_ => {
            Thread.currentThread().setName("Thread " + ch)
            fibonacci(5)
          })
        }
      }).start)
  }
}

0
除了Landei的回答,我还想建议在Scala中进行DP的自底向上(非记忆化)方法是可行的,核心思想是使用foldLeft(s)。
计算斐波那契数列的示例:
  def fibo(n: Int) = (1 to n).foldLeft((0, 1)) {
    (acc, i) => (acc._2, acc._1 + acc._2)
  }._1

最长递增子序列的示例

def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]) = {
  xs.foldLeft(List[(Int, List[T])]()) {
    (memo, x) =>
      if (memo.isEmpty) List((1, List(x)))
      else {
        val resultIfEndsAtCurr = (memo, xs).zipped map {
          (tp, y) =>
            val len = tp._1
            val seq = tp._2
            if (ord.lteq(y, x)) { // current is greater than the previous end
              (len + 1, x :: seq) // reversely recorded to avoid O(n)
            } else {
              (1, List(x)) // start over
            }
        }
        memo :+ resultIfEndsAtCurr.maxBy(_._1)
      }
  }.maxBy(_._1)._2.reverse
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接