在Haskell中记忆化多维递归解法

7
我正在解决一个递归问题,使用haskell语言,虽然我已经得到了解决方案,但我想缓存子问题的输出结果,因为存在重叠子问题的属性。
问题是:给定一个n*m尺寸的网格和一个整数k,从(1,1)出发到达(n,m),最多改变k次方向,有多少种方式?
以下是没有记忆化的代码。
paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
    | i > n || j > m || k < 0 = 0
    | i == n && j == m = 1
    | dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2        -- is in grid (1,1)
    | dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2    -- down was the direction took to reach here
    | dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2    -- right was the direction took to reach here 
    | otherwise = -1

这里的依赖变量是ijkdir。在像C ++ / Java这样的语言中,可以使用4-d DP数组(dp [n] [m] [k] [3]),但是在Haskell中,我找不到实现这种方式的方法。

2个回答

6
"绑定"是一种众所周知的技术,可让 GHC 运行时为您记忆结果,如果您事先知道需要查找的所有值。这个想法是将递归函数转换为自引用数据结构,然后简单地查找您实际关心的值。我选择使用数组,但映射也可以。在任何情况下,您使用的数组或映射必须是惰性/非严格的,因为我们将插入尚未准备好计算的值,直到整个数组填充为止。
import Data.Array (array, bounds, inRange, (!))

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2)    -- down was the direction took to reach here
          | dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2)    -- right was the direction took to reach here
          | otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2)     -- is in grid (1,1)
        a = array ((1, 1, 0, 1), (m, n, k, 2))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

我稍微简化了你的API:
  • mn参数在每次迭代中不会改变,因此它们不应该是递归调用的一部分。
  • 客户端不应该告诉您ijdir从哪里开始,因此它们已从函数签名中删除,并隐式地从1、1和0开始。
  • 我还交换了mn的顺序,因为首先取一个n参数有点奇怪。这让我头疼了很久,因为我没有注意到我还需要更改基本情况!

然后,正如我之前所说,想法是用所有我们需要进行的递归调用填充数组: 这就是array调用。请注意,array中的单元格通过调用go进行初始化,这(除了基本情况!)涉及调用get,这涉及查找数组中的元素。以这种方式,a具有自引用或递归性质。但是我们不必决定按什么顺序查找或按什么顺序插入:我们足够懒惰,以便GHC根据需要评估数组元素。

我还有点聪明,只为dir=1dir=2留出了数组的空间,而不是dir=0。我做到了这一点,因为dir=0只在第一次调用时发生,并且我可以直接调用go处理该情况,绕过get中的边界检查。这个技巧意味着如果你传递一个小于1的mn,或者k小于零,你将会得到运行时错误。如果需要处理该情况,则可以为paths本身添加一个保护程序。

当然,它确实起作用:

> paths 3 3 2
4

还有一件事情,您可以做的是使用真实的数据类型作为方向,而不是使用 Int

import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)

data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | otherwise = case dir of
            Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
            Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
            Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
        a = array ((1, 1, 0, Down), (m, n, k, Right))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

(我认为使用I和J可能比Down和Right更好,不确定哪个更容易记忆)。我认为这样做可能是一种改进,因为类型现在具有更多的含义,并且您不必处理像dir = 7 这样应该是非法的奇怪的 otherwise 子句。但是,它仍然有点奇怪,因为它依赖于枚举值的顺序:如果我们将 Neutral 放在 Down 和 Right 之间,它将会出错。(我尝试完全删除Neutral 方向并为第一步添加更多特殊情况,但这样做会变得很麻烦)

4
在Haskell中,这些问题并不是最简单的。事实上,您可能希望进行一些原地变异以节省内存和时间,因此我认为没有比配备可怕的ST单子更好的方法。
这可以在各种数据结构、数组、向量、repa张量中完成。我选择了hashtables中的HashTable,因为它最简单易用且在我的示例中具有足够的性能。
首先,介绍一下:
{-# LANGUAGE Rank2Types #-}
module Solution where

import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT

Rank2Types在处理ST时非常有用,因为使用幻影类型。我选择了哈希表的基本变体Basic,因为作者声称它具有最快的查找速度---而我们将要进行大量的查找。

建议为地图使用类型别名,所以这里我们可以这样做:

type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer

ST-free入口只是为了创建地图并调用我们的怪物:

runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
  mem <- HT.new
  paths mem i j n m k dir

这里是paths的记忆化计算。我们只需要在地图中搜索结果,如果结果不在地图中,则保存结果并返回:

mempaths mem i j n m k dir = do
  res <- HT.lookup mem (i, j, k, dir)
  case res of
    Just x -> return x
    Nothing -> do
      x <- paths mem i j n m k dir
      HT.insert mem (i, j, k, dir) x
      return x

这里是算法的核心部分,它只是一种单子动作,使用带有记忆化调用代替纯递归:

paths mem i j n m k dir
    | i > n || j > m || k < 0 = return 0
    | i == n && j == m = return 1
    | dir == 0 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m k 2        -- is in grid (1,1)
        return $ x1 + x2
    | dir == 1 = do 
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m (k-1) 2    -- down was the direction took to reach here
        return $ x1 + x2
    | dir == 2 = do
        x1 <- mempaths mem (i+1) j n m (k-1) 1 
        x2 <- mempaths mem i (j+1) n m k 2    -- right was the direction took to reach here 
        return $ x1 + x2
    | otherwise = return (-1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接