Haskell中的记忆化?

149

有没有关于如何高效地解决以下Haskell函数的指针,适用于大于108的数字(n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

我在Haskell中看到了一些使用记忆化来解决斐波那契数列的例子,这些例子涉及计算(惰性地)所有需要的斐波那契数列数字。但是,在这种情况下,对于给定的n,我们只需要计算非常少量的中间结果。

谢谢


111
仅仅意味着这是我在家做的一些工作 :-) - Angel de Vicente
8个回答

279

我们可以通过创建一个可以在次线性时间内索引的结构来高效地完成此操作。

但首先,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

让我们定义f,但使用“开放递归”而不是直接调用自身。

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

使用 fix f 可以得到一个未记忆化的 f

这将使您可以通过调用例如:fix f 123 = 144,来测试小值的f是否符合您的意思。

我们可以通过定义以下内容来进行记忆化:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

这个程序表现还不错,用一个能够记忆中间结果的方式来代替原本需要O(n^3)时间复杂度的算法。

但是仍然需要线性时间来寻找的记忆化答案。这意味着结果如下:

*Main Data.List> faster_f 123801
248604

这些可接受,但结果的可扩展性并不比那更好。我们可以做得更好!

首先,让我们定义一个无限树:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

然后我们将定义一种索引方式,以便我们可以在O(log n)时间内找到一个具有索引n的节点:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

...而且我们可能会发现使用一个充满自然数的树比去操作那些索引更方便:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

既然我们可以进行索引,那么你可以将树转换为列表:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

您可以通过验证toList nats返回[0..]来检查迄今为止的工作。

现在,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

它的作用与上面的列表类似,但是不需要花费线性时间来查找每个节点,而是可以在对数时间内追踪它。

结果更快。

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

事实上,它非常快,以至于你可以使用 Integer 替换上面的Int,几乎瞬间得到非常大的答案。

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

如果需要一个开箱即用的树形记忆化库,请使用MemoTrie

$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

5
我会尽力进行翻译和澄清内容,同时保持原意。以下是需要翻译的内容:我尝试了这段代码,有趣的是,f_faster似乎比f要慢。我猜那些列表引用真的拖慢了速度。对于nats和index的定义对我来说很神秘,所以我添加了自己的答案,可能会使事情更清晰。 - Pitarou
6
无限列表情况涉及到一个包含111111111个项目的链表。树的情况则涉及到log n * 所达到的节点数。 - Edward Kmett
2
列表版本必须为列表中的所有节点创建thunk,而树版本避免了创建大量thunk的情况。 - Tom Ellis
8
我知道这是一个相当老的帖子,但是为了避免在多次调用之间保存不需要的路径,f_tree 应该在 where 子句中定义吗? - dfeuer
19
将其放入 CAF 中的原因是可以实现跨调用的记忆化。如果我正在进行昂贵的调用并对其进行记忆化,那么我可能会将其留在 CAF 中,因此这里展示了此技术。在实际应用中,永久记忆化的成本与收益之间存在权衡关系。虽然,考虑到问题是如何实现记忆化,如果用一个有意避免跨调用记忆化的技术来回答,那么我认为这可能会引导人们产生误解。如果什么都不做,这里的评论也会指出一些细微差别。 ;) - Edward Kmett
显示剩余7条评论

20

Edward的回答是一个非常精彩的宝石,我已经复制了它,并提供了memoListmemoTree组合器的实现,这些组合器以开放递归形式记忆函数。

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f

12

不是最有效的方法,但可以使用记忆化技术:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

当请求f !! 144时,会检查是否存在f !! 143,但并不计算其精确值。它仍被设置为某个未知计算结果。仅有所需的确切值被计算。

因此,最初程序并不知道计算了多少东西。

f = .... 

当我们发出请求 f !! 12 时,它开始进行一些模式匹配:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在它开始计算

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

这会对 f 递归地提出另一个要求, 因此我们计算

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

现在我们可以逐步向上回溯一些

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

这意味着程序现在知道:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

持续向上滴漏:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

这意味着程序现在知道:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算 f!!6

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

这意味着程序现在知道:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算 f!!12

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

这意味着程序现在知道:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

所以计算是相当懒惰的。程序知道某个值为f !! 8存在,并且它等于g 8,但不知道g 8是什么。


谢谢您提供这个问题。您如何创建和使用二维解空间?那会是一个列表的列表吗?以及 g n m = (something with) f!!a!!b - vikingsteve
1
当然可以。不过,为了得到一个真正的解决方案,我可能会使用记忆化库,比如memocombinators - rampion
很遗憾,它的时间复杂度是O(n^2)。 - Qumeric

9

正如Edward Kmett所说,为了加快速度,您需要缓存昂贵的计算并能够快速访问它们。

为了保持函数的非单调性,构建一个无限惰性树的解决方案,以及适当的索引方式(如前面的帖子所示)实现了这一目标。如果您放弃函数的非单调性,可以使用Haskell中提供的标准关联容器与“类似状态”的单子(例如State或ST)相结合。

主要缺点是您得到了一个非单调函数,但是您不再需要自己对结构进行索引,而可以直接使用关联容器的标准实现。

要做到这一点,首先需要重写函数以接受任何类型的单子:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

对于您的测试,您仍然可以使用Data.Function.fix定义一个不进行记忆化的函数,尽管它略微冗长:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

然后你可以将State monad与Data.Map结合使用,以加速处理速度:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

通过一些小的修改,您可以使代码适用于 Data.HashMap:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

除了持久化数据结构,您还可以尝试使用可变数据结构(如Data.HashTable)与ST monad相结合:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

与没有任何记忆化实现相比,这些实现方案允许您在处理大量输入时以微秒级别获得结果,而不必等待几秒钟。
使用Criterion作为基准测试,我发现使用Data.HashMap的实现比Data.Map和Data.HashTable的实现略好(约20%),其中Data.Map和Data.HashTable的计时非常相似。
基准测试的结果有点令人惊讶。我的初步感觉是HashTable会比HashMap实现更好,因为它是可变的。但是可能存在一些隐藏在最后一个实现中的性能缺陷。

3
GHC在优化不可变结构方面做得非常好。从C语言的直觉并不总是适用的。 - John Tyree

9
这是对Edward Kmett所提供的优秀答案的补充。
当我尝试他的代码时,定义中的nats和index似乎很神秘,因此我写了一个我认为更容易理解的替代版本。
我通过index'和nats'来定义index和nats。
index' t n在范围[1..]内进行定义。 (请回忆一下,index t在范围[0..]内进行定义。) 它通过将n视为一串位并反向读取这些位来搜索树。 如果该位为1,则采用右侧分支。 如果该位为0,则采用左侧分支。 当它到达最后一位(必须是1)时停止。
index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

正如nats针对index定义一样,使得index nats n == n始终为真,nats'针对index'进行定义。

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

现在,natsindex 简单地是 nats'index',但值向右移了一个位置:
index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'

谢谢。我正在记忆化一个多元函数,这确实帮助我弄清楚索引和 nats 到底在做什么。 - Kittsil

4
几年后,我看到这个问题并意识到可以使用zipWith和一个辅助函数在线性时间内进行记忆化处理:
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate具有方便的属性,即dilate n xs !! i == xs !! div i n

因此,假设我们已经知道了 f(0),这就简化了计算过程:

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

看起来非常像我们最初的问题描述,并且给出了线性解决方案(sum $ take n fs将需要O(n)的时间)。


2
所以这是一个生成(corecursive?)或动态编程的解决方案。每个生成值都需要O(1)时间,就像通常的斐波那契数列一样。太好了!而EKMETT的解决方案就像对数级别的大斐波那契数列,可以更快地得到大数字,跳过许多中间值。这样说对吗? - Will Ness
或许更接近于汉明数的那个,它有三个回溯指针指向正在生成的序列,并且每个指针在其上以不同的速度前进。非常漂亮。 - Will Ness

2

一种不需要索引且不基于 Edward KMETT 的解决方案。

我将共同的子树分离出来,放到一个共同的父节点中(f(n/4)f(n/2)f(n/4) 之间共享,f(n/6)f(2)f(3) 之间共享)。将它们作为单个变量保存在父节点中,这样子树的计算只需要进行一次。

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

该代码不容易扩展到通用的记忆化函数(至少,我不知道怎么做),而且你必须考虑子问题如何重叠,但该策略应适用于多个非整数参数。(我是为两个字符串参数设计的。)
每次计算后都会丢弃备忘录。(同样,我考虑的是两个字符串参数。)
我不知道这是否比其他答案更有效率。从技术上讲,每次查找只需要一两步(“看看你的孩子或你孩子的孩子”),但可能会使用大量额外的内存。
编辑:此解决方案尚不正确。共享是不完整的。
编辑:现在应该正确地共享了子节点,但我意识到这个问题有很多非平凡的共享:n/2/2/2和n/3/3可能相同。这个问题不适合我的策略。

2
Edward Kmett的回答又有了一个补充:一个自包含的例子:

仍然与Edward Kmett的回答相同:一个自包含的例子:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

以下是使用方法,将函数记忆化为单个整数参数(例如斐波那契数列):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

仅会缓存非负参数的值。

如果需要缓存负参数的值,请使用以下定义的memoInt

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

要为具有两个整数参数的函数缓存值,请使用memoIntInt,定义如下:

memoIntInt f = memoInt (\n -> memoInt (f n))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接