在Clojure中,是否可能结合记忆化和尾调用优化?

15

在Clojure中,我想编写一个尾递归函数,可以记忆其中间结果以供后续调用。

[编辑:本问题已使用gcd作为示例重写,而不是factorial。]

可将记忆化的gcd(最大公约数)实现如下:

(def gcd (memoize (fn [a b] 
   (if (zero? b) 
       a 
       (recur b (mod a b)))) 
在这个实现中,中间结果不会被记忆化以供subsequent调用使用。例如,为了计算gcd(9,6),会调用gcd(6,3)作为中间结果。但是gcd(6,3)没有被存储在记忆化函数的缓存中,因为recur的递归点是未被记忆化的匿名函数。

因此,如果在调用gcd(9,6)之后,我们调用gcd(6,3),我们将无法从记忆化中获益。

我唯一能想到的解决方案是使用普通递归(显式调用gcd而不是recur),但这样我们就无法受益于尾调用优化。

重要结论

有没有办法同时实现:

  1. 尾调用优化
  2. 记忆中间结果以供后续调用使用

注释

  1. 这个问题类似于Combine memoization and tail-recursion。但那里的所有答案都与F#有关。在这里,我正在寻找clojure的答案。
  2. 这个问题留给读者作为The Joy of Clojure(第12.4章)的练习题。您可以在http://bit.ly/HkQrio上查看相关页面。
5个回答

8

在你的情况下,很难展示记忆化对阶乘函数有任何作用,因为中间调用都是唯一的,所以我将重写一个有点牵强的例子,假设重点是探索避免堆栈溢出的方法:

(defn stack-popper [n i] 
    (if (< i n) (* i (stack-popper n (inc i))) 1)) 

然后可以从记忆化中获取一些东西:

(def stack-popper 
   (memoize (fn [n i] (if (< i n) (* i (stack-popper n (inc i))) 1))))

不让堆栈溢出的一般方法有:

使用尾调用

(def stack-popper 
    (memoize (fn [n acc] (if (> n 1) (recur (dec n) (* acc (dec n))) acc))))

使用 蹦床

(def stack-popper 
    (memoize (fn [n acc] 
        (if (> n 1) #(stack-popper (dec n) (* acc (dec n))) acc))))
(trampoline (stack-popper 4 1))

使用惰性序列
(reduce * (range 1 4))

虽然这些方法并非总是有效,但我还没有遇到过完全无法使用的情况。我几乎总是首选简单懒惰的方法,因为我认为它们最接近Clojure的风格,然后再尝试使用recur或trampolines实现尾递归。


1
当然,我很快会添加一些示例。 - Arthur Ulfeldt
懒惰序列版本不是记忆化的好例子。 - Arthur Ulfeldt
(* 4 (fact 3)) -> (* 3 (fact 2)) -> (* 2 (fact 1)) -> 1: fib 作为更好的记忆化示例。 - Arthur Ulfeldt

2
(defmacro memofn
  [name args & body]
  `(let [cache# (atom {})]
     (fn ~name [& args#]
       (let [update-cache!# (fn update-cache!# [state# args#]
                              (if-not (contains? state# args#)
                                (assoc state# args#
                                       (delay
                                         (let [~args args#]
                                           ~@body)))
                                state#))]
         (let [state# (swap! cache# update-cache!# args#)]
           (-> state# (get args#) deref))))))

这将允许递归定义一个记忆化函数,同时缓存中间结果。用法:

(def fib (memofn fib [n]
           (case n
             1 1
             0 1
             (+ (fib (dec n)) (fib (- n 2))))))

为什么这比标准的“memoize”函数更好? - viebel
1
@viebel 看看这个讨论 - 经典之作。 - kotarak
@viebel 啊。但恐怕它无法解决您的尾递归问题。 - kotarak
@viebel 我不确定,但我认为由于JVM的限制,这通常是不可能的。但是,您可以编写自己的专用函数,将逻辑和记忆化合并到一个函数中。然后,我想recur会起作用。但是,如果多个线程同时调用该函数,则可能会出现一些问题。在这种情况下,您需要锁定或类似的东西来协调对缓存的访问。我还没有仔细考虑过这个问题。 - kotarak
我不需要线程安全。您能否详细说明如何“编写自己的专用函数,将逻辑和记忆化结合在一个函数中”? - viebel
显示剩余2条评论

2
(def gcd 
  (let [cache (atom {})]
    (fn [a b]
      @(or (@cache [a b])
         (let [p (promise)]
           (deliver p
             (loop [a a b b]
               (if-let [p2 (@cache [a b])]
                 @p2
                 (do
                   (swap! cache assoc [a b] p)
                   (if (zero? b) 
                     a 
                     (recur b (mod a b))))))))))))

存在一些并发问题(双重评估,与memoize相同的问题,但由于承诺更糟)可能会使用@kotarak的建议进行修复。

将上述代码转换为宏留给读者作为练习。(Fogus的注释是我认为的玩笑话。)

将其转换为宏实际上是宏学中的简单练习,请注意主体(最后3行)保持不变。


0

使用Clojure的recur,您可以编写阶乘函数,使用一个没有堆栈增长的累加器,并将其进行记忆化:

(defn fact
  ([n]
     (fact n 1))
  ([n acc]
     (if (= 1 n)
       acc
       (recur (dec n)
              (* n acc)))))

2
我认为这不会起作用,因为recur不会调用记忆化函数。相反,它将使用非记忆化的fact作为递归点。这意味着中间结果没有被记忆化。 - viebel
你是正确的,如果你想要那些中间值被记忆化,那么对 recur 的调用可以被替换为对 fact 的调用。对于这种阶乘的实现方式来说,中间值的记忆化并不像对于斐波那契数列那样有价值(乘法很便宜)。 - Kyle Burton

0

这是使用匿名递归尾调用记忆化中间结果实现的阶乘函数。记忆化与函数集成,通过词法闭包传递对共享缓冲区(使用{{link4:Atom}}引用类型实现)的引用。

由于阶乘函数操作自然数,并且连续结果的参数是增量的,因此Vector似乎更适合存储缓冲结果的数据结构。

我们不是将先前计算的结果作为参数(累加器)传递,而是从缓冲区获取它。

(def !                                            ; global variable referring to a function
  (let [m (atom [1 1 2 6 24])]                    ; buffer of results
    (fn [n]                                       ; factorial function definition
      (let [m-count (count @m)]                   ; number of results in a buffer
        (if (< n m-count)                         ; do we have buffered result for n?
          (nth @m n)                              ; · yes: return it
          (loop [cur m-count]                     ; · no: compute it recursively
            (let [r (*' (nth @m (dec cur)) cur)]  ; new result
              (swap! m assoc cur r)               ; store the result
              (if (= n cur)                       ; termination condition:
                r                                 ; · base case
                (recur (inc cur))))))))))         ; · recursive case

(time (do (! 8000) nil))  ; => "Elapsed time: 154.280516 msecs"
(time (do (! 8001) nil))  ; => "Elapsed time: 0.100222 msecs"
(time (do (! 7999) nil))  ; => "Elapsed time: 0.090444 msecs"
(time (do (! 7999) nil))  ; => "Elapsed time: 0.055873 msecs"

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接