在Haskell中实现Floyd-Warshall算法

3

我正在学习 Floyd-Warshall 算法。现在已经成功地在 Haskell 中实现了它,我的实现方式类似于在命令式语言中的实现方式(也就是说,使用列表的列表来模拟二维数组),但这样做效率非常低下,因为在列表中访问一个元素要比在数组中慢得多。

在 Haskell 中有更聪明的方法吗?我曾尝试通过连接一些列表来实现,但总是失败。

我的代码:

floydwarshall :: [[Weight]] -> [[Weight]]
floydwarshall lst = fwAlg 1 $ initMatrix 0 $ list2matrix lst

fwAlg :: Int -> [[Weight]] -> [[Weight]]
fwAlg k m | k < rows m = let n = rows m
                             m' = foldl (\m (i,j) -> updateDist i j k m) m [(i,j) | i <- [0..n-1], j <- [0..n-1]]
                        in fwAlg (k+1) m'
          | otherwise = m

-- a special case where k is 0
initMatrix :: Int -> [[Weight]] -> [[Weight]]
initMatrix n m = if n == rows m then m else initMatrix (n+1) $ updateAtM 0.0 (n,n) m

updateDist :: Int -> Int -> Int -> [[Weight]] -> [[Weight]]
updateDist i j k m =
    let w = min (weight i j m) (weight i k m + weight k j m)
    in updateAtM w (i, j) m

weight :: Vertice -> Vertice -> [[Weight]] -> Weight
weight i j m = let Just w = elemAt (i, j) m in w

顺便说一下,我的实现可能不是100%正确的;我在i和j的边界情况下遇到了一些问题。 - Chien
3
我无法想到一种使用列表高效实现 F-W(Floyd-Warshall)算法的方法。可能最好使用一个 STUArray s (Int, Int) Int 作为中间高效表示,并在最后使用 runSTUArray 方法获得非 ST 值。这种方式会更像命令式编程风格,但有时候没有更优雅的高效替代方案。 - chi
1
问题是列表不是数组,那为什么不使用数组呢? - Daniel Wagner
2个回答

1
该算法具有规律的访问模式,因此我们可以避免大量索引并仍然使用列表编写它,具有与命令式版本相同的渐近性能。如果您确实想要使用数组以获得更快的速度,则可能仍希望对行和列进行批量操作,而不是读取和写入单个单元格。
-- Let's have a type for weights.  We could use Maybe but the ordering
-- behaviour is wrong - when there's no weight it should be like
-- +infinity.
data Weight = Weight Int | None deriving (Eq, Ord, Show)

addWeights :: Weight -> Weight -> Weight
addWeights (Weight x) (Weight y) = Weight (x + y)
addWeights _ _ = None

-- the main function just steps the matrix a number of times equal to
-- the node count.  Also pass along k at each step.
floydwarshall :: [[Weight]] -> [[Weight]]
floydwarshall m = snd (iterate step (0, m) !! length m)

-- step takes k and the matrix for k, returns k+1 and the matrix for
-- k+1.
step :: (Int, [[Weight]]) -> (Int, [[Weight]])
step (k, m) = (k + 1, zipWith (stepRow ktojs) istok m)
  where
    ktojs = m !! k  -- current k to each j
    istok = transpose m !! k  -- each i to current k

-- Make shortest paths from one i to all j.
-- We need the shortest paths from the current k to all j
-- and the shortest path from this i to the current k
-- and the shortest paths from this i to all j
stepRow :: [Weight] -> Weight -> [Weight] -> [Weight]
stepRow ktojs itok itojs = zipWith stepOne itojs ktojs
  where
    stepOne itoj ktoj = itoj `min` (itok `addWeights` ktoj)

-- example from wikipedia for testing
test :: [[Weight]]
test = [[Weight 0, None, Weight (-2), None],
        [Weight 4, Weight 0, Weight 3, None],
        [None, None, Weight 0, Weight 2],
        [None, Weight (-1), None, Weight 0]]

0

我不知道如何实现最佳性能,但我可以给你一些关于使代码抽象化的提示,这样你就可以更轻松地进行性能调优。

首先,如果您改变数据类型时不必重写所有内容,那将是很好的。现在,您已经使所有内容都具体化为列表的列表,因此让我们看看是否可以将其抽象化。首先,我们必须确定您的最小矩阵接口是什么。浏览您的代码,您似乎有initMatrixlist2matrixrowselemAtupdateAtM。这些是查询或修改矩阵的函数,这些是您需要实现的内容,以便为不同的矩阵类型创建此代码的新版本。

组织此接口的一种方法是将其制作成一个类。例如:

class Matrix m where
  list2matrix :: [[a]] -> m a
  matrix2List :: m a -> [[a]]
  rows :: m a -> Int
  elemAt :: Int -> Int -> m a -> a
  updateAtM :: a -> (Int, Int) -> m a -> m a
  setDiag :: a -> m a -> m a

我已经添加了一个matrix2List函数来提取您的结果,并将initMatrix重命名/修改为setDiag,这感觉更通用一些。

然后我们可以更新您的代码以使用这个新类:

floydwarshall :: Matrix m => [[Weight]] -> m Weight
floydwarshall lst = fwAlg 1 $ initMatrix $ list2matrix lst

fwAlg :: Matrix m => Int -> m Weight -> m Weight
fwAlg k m | k < rows m = let n = rows m
                             m' = foldl (\m (i,j) -> updateDist i j k m) m [(i,j) | i <- [0..n-1], j <- [0..n-1]]
                        in fwAlg (k+1) m'
          | otherwise = m

initMatrix :: Matrix m => m Weight -> m Weight
initMatrix = setDiag 0

updateDist :: Matrix m => Int -> Int -> Int -> m Weight -> m Weight
updateDist i j k m =
    let w = min (elemAt i j m) (elemAt i k m + elemAt k j m)
    in updateAtM w (i, j) m

dist :: Matrix m => Int -> Int -> Int -> m Weight -> Weight
dist i j 0 m = elemAt i j m
dist i j k m = min (dist i j (k-1) m) (dist i k (k-1) m + dist k j (k-1) m)

现在我们所需要做的就是开始定义一些矩阵类型并观察性能如何!让我们从列表开始,因为您已经完成了这项工作。我们将不得不使用newtype包装器来使GHC满意,但是忽略包装和解包装,这与您编写的代码在道德上是相同的:
newtype ListMatrix a = ListMatrix { getListMatrix :: [[a]] }

instance Matrix ListMatrix where
  list2matrix = ListMatrix
  matrix2List = getListMatrix
  rows = length . getListMatrix
  elemAt i j (ListMatrix m) = m !! i !! j
  updateAtM a (i,j) (ListMatrix m) =
    let (firstRows, row:laterRows) = splitAt i m
        (firstCols, _:laterCols) = splitAt j row
    in ListMatrix $ firstRows <> ((firstCols <> (a:laterCols)):laterRows)
  setDiag x = go 0
    where go n m = if n == rows m then m else go (n+1) $ updateAtM x (n,n) m

(此外,我填写了elemAtupdateAtM。)你应该可以运行

matrix2List @ListMatrix $ floydwarshall myList

并获得您当前拥有的相同结果(和性能)。

现在,开始实验!我们只需要定义Matrix的新实例并观察发生了什么。也许我们应该尝试纯函数:

data FunMatrix a = FunMatrix { size :: Int, getFunMatrix :: Int -> Int -> a }

instance Matrix FunMatrix where
  list2matrix l = FunMatrix (length l) (\i j -> l !! i !! j)
  matrix2List (FunMatrix s f) = (\i -> f i <$> [0..s-1]) <$> [0..s-1]
  rows = size
  elemAt i j m = getFunMatrix m i j
  updateAtM a (i,j) (FunMatrix s f) = FunMatrix s (\i' j' -> if i==i' && j==j' then a else f i' j')
  setDiag x (FunMatrix s f) = FunMatrix s (\i j -> if i==j then x else f i j)

那个函数表现如何?一个问题是起始查找函数仍然只是索引到列表的列表中,这很慢。一个解决方法是先转换为数组或向量,然后再进行索引。因为我们已经很好地抽象化了所有内容,所以唯一需要更改的就是在此处定义 list2matrix,然后您可能会获得良好的性能提升!


关于性能问题,还有一个需要注意的地方。在定义dist时使用了一些严重的“动态规划”。如果你直接写入和读取数组,这可能没问题,但在递归形式中,你可能会做很多重复的工作。解决方法是进行记忆化处理。Memoize。我的记忆化处理包是MemoTrie,它使记忆化处理变得非常容易。在这种情况下,你可以将dist更改为:

dist :: Matrix m => m Weight -> Int -> Int -> Int -> Weight
dist m = go'
  where
    go' = memo3 go
    go i j 0 = elemAt i j m
    go i j k = min (go' i j (k-1)) (go' i k (k-1) + go' k j (k-1))

这可能会给你一些提升!


你可能想要考虑采纳 @Chi 的建议,使用 STUArray,但是你会遇到一个问题: STUArray 接口要求数组查找在单子中进行。仍然可以使用我展示的抽象方法,但是你需要改变函数的类型。并且,因为你在接口中更改了类型,所以需要将算法代码更新为单子化的代码。这可能有点麻烦,但是为了获得最佳性能可能是必要的。

谢谢你的工作,那是相当多的代码!关于动态部分,我考虑过但不敢使用。我尝试使用类似 (dists i j k m) !! (k-1) 的代码,其中 dists :: [Weight],但不确定哪个更快。你能帮忙一下吗?还是我应该在另一个问题中问这个? - Chien
抱歉,我不确定你在这里在问什么。 - DDub
没问题,还是谢谢你的帮助!:) - Chien
这看起来像是Java。 - Random dude

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接