Efficiently finding the maximum across multiple lists in Haskell

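I'm writing an algorithm to find a number of turning points on a path, given a list of coordinates that describe the path. The dynamic-programming algorithm runs nicely in O(kn^2), where k is the number of turning points and n is the number of points. Long story short: the slowest part is computing the distance between two coordinates, and the algorithm requires that distance to be recomputed 'k' times for the same pair of points. Memoization is not an option (there are too many points). It is possible to "invert" the algorithm, but somehow the inverted version is very slow in Haskell and consumes far too much memory.
I think the problem boils down to this: given a list of fixed-size lists, each paired with some dynamically computed value (for example, the result of zipping those values with the lists):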
arr = [ (2, [10,5,12]), (1, [2,8, 20]), (4, [3, 2, 10]) ]

I am trying to find, position by position, the maximum of the list elements plus their fixed value:

[12, 9, 21]
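Stated as a formula, writing f_i for the fixed value of row i and x_{i,j} for the j-th element of that row (notation introduced here only for illustration), the desired result is the column-wise maximum, checked against the example above:

```latex
r_j = \max_i \left( f_i + x_{i,j} \right)

r_0 = \max(2+10,\ 1+2,\ 4+3) = 12, \qquad
r_1 = \max(2+5,\ 1+8,\ 4+2) = 9, \qquad
r_2 = \max(2+12,\ 1+20,\ 4+10) = 21
```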

What I'm doing is something like this:

import Data.List (foldl')

foldl' getbest (replicate 3 0) arr
getbest acc (fixval, item) = zipWith comparator acc item
    where
        -- local so that it can see fixval
        comparator orig new
            | new + fixval > orig = new + fixval
            | otherwise = orig

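The problem is that every call to "getbest" builds a new "acc"; over the whole computation that is n^2 work and very resource-hungry. Allocating memory is expensive, and that is probably where the problem lies. Do you have any idea how to do this efficiently?
For clarity, here is the actual code of the function: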
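One detail worth knowing here: foldl' only forces the accumulator list to weak head normal form, i.e. its first cons cell, so the elements themselves can pile up as chains of unevaluated max thunks. A minimal sketch that forces every element after each step is below. The names stepStrict and bestStrict are hypothetical, starting from 0 assumes all sums are non-negative (as the question's replicate 3 0 does), and it has not been benchmarked; it still allocates one new list per row, so it addresses thunk buildup rather than allocation.

```haskell
import Data.List (foldl')

-- One step: add the row's fixed value to each element and keep the larger value.
-- forceElems evaluates every element before the list is returned, so the
-- accumulator never holds a chain of unevaluated 'max' thunks.
stepStrict :: [Int] -> (Int, [Int]) -> [Int]
stepStrict acc (fixval, item) = forceElems (zipWith pick acc item)
  where
    pick orig new = max orig (new + fixval)
    forceElems xs = foldr seq () xs `seq` xs

-- Hypothetical wrapper; seeding with zeros assumes non-negative sums.
bestStrict :: Int -> [(Int, [Int])] -> [Int]
bestStrict width = foldl' stepStrict (replicate width 0)

main :: IO ()
main = print (bestStrict 3 [(2, [10, 5, 12]), (1, [2, 8, 20]), (4, [3, 2, 10])])
```

With the question's data this prints [12,9,21].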
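import Data.List (foldl', maximumBy, transpose)

-- Coord and distance :: Coord -> Coord -> Double are defined elsewhere;
-- DSPoint is reconstructed from how it is used below.
data DSPoint = DSPoint { dsCoord :: Coord, dsScore :: [ (Double, [ Coord ]) ] }

dynamic2FreeFlight :: Int -> [ Coord ] -> [ Coord ]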
dynamic2FreeFlight numpoints points = reverse $ (dsCoord bestPoint) : (snd $ (dsScore bestPoint) !! (numpoints - 2))
    where
        bestPoint :: DSPoint
        bestPoint = maximumBy (\x y -> (getFinalPointScore x) `compare` (getFinalPointScore y)) compresult

        getFinalPointScore :: DSPoint -> Double
        getFinalPointScore sc = fst $ (dsScore sc) !! (numpoints - 2)

        compresult :: [ DSPoint ]
        compresult = foldl' onestep [] points 

        onestep :: [ DSPoint ] -> Coord -> [ DSPoint ]
        onestep lst point = (DSPoint point (genmax lst)) : lst
            where
                genmax :: [ DSPoint ] -> [ (Double, [ Coord ]) ]
                genmax lst = map (maximumBy comparator) $ transpose prepared
                comparator a b = (fst a) `compare` (fst b)
                distances :: [ Double ]
                distances = map (distance point . dsCoord) lst
                prepared :: [ [ (Double, [ Coord ]) ] ]
                prepared 
                    | length lst == 0 = [ replicate (numpoints - 1) (0, []) ]
                    | otherwise = map prepare $ zip distances lst
                prepare :: (Double, DSPoint) -> [ (Double, [ Coord ]) ]
                prepare (dist, item) = (dist, [dsCoord item]) : map addme (take (numpoints - 2) (dsScore item))
                    where
                        addme (score, coords) = (score + dist, dsCoord item : coords)
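Read off the code above (this is an interpretation, not something the question states), the j-th entry of dsScore for a point p holds S_j(p), the best total distance of a path with j legs ending at p. Here m = numpoints, d = distance, and q ≺ p means q appears before p in points; for the very first point every S_j is 0:

```latex
S_1(p) = \max_{q \prec p} d(q, p), \qquad
S_j(p) = \max_{q \prec p} \bigl( S_{j-1}(q) + d(q, p) \bigr), \quad 2 \le j \le m - 1

\text{result} = \max_p S_{m-1}(p)
```

Each of the n points scans every earlier point for all m - 1 entries, which gives the O(kn^2) cost mentioned in the question.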

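[a,b,c] is not an array but a (singly) linked list. - sepp2k
Where does [12, 9, 21] come from? - Gabe
12 is the maximum of "first item + fixed number" (e.g. 10 + 2), 9 is the maximum of "second item + fixed number" (8 + 1), and so on. - ondra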
4 Answers


Benchmarking Travis Brown's, sclv's, KennyTM's and your own solutions with the following:

import Data.List
import Criterion.Main
-- Criterion.Config was unused here and no longer exists in current criterion releases
import qualified Data.Vector as V

-- Vector based solution (Travis Brown)
bestVector :: V.Vector (V.Vector Int) -> V.Vector Int -> V.Vector Int
bestVector = (V.foldl1' (V.zipWith max) .) . (V.zipWith . flip $ V.map . (+))

convertVector :: [[Int]] -> V.Vector (V.Vector Int)
convertVector = V.fromList . map V.fromList

arrVector = convertVector arr
valVector = V.fromList  val :: V.Vector Int

-- Shared arr and val
arr = [map (x*) [1, 2.. 2000] | x <- [1..1000]]
val = [1..1000]

-- SCLV solution
bestSCLV = foldl' (zipWith max) (repeat 0) . map (\(fv,xs) -> map (+fv) xs)

-- KennyTM Solution
bestKTM arr = map maximum $ transpose [ map (a+) bs | (a,bs) <- arr]

-- Original
getbest :: [Int] -> (Int, [Int]) -> [Int]
getbest acc (fixval, item) = map (uncurry comparator) $ zip acc item
 where
  comparator o n = max (n + fixval) o

someFuncOrig = foldl' getbest acc
  where acc = replicate 2000 0

-- top level functions
someFuncVector :: (V.Vector (V.Vector Int), V.Vector Int) -> V.Vector Int
someFuncVector = uncurry bestVector
someFuncSCLV = bestSCLV
someFuncKTM = bestKTM

main = do
  let vec = someFuncVector (arrVector, valVector) :: V.Vector Int
  print (someFuncOrig (zip val arr) == someFuncKTM (zip val arr)
        , someFuncKTM (zip val arr) == someFuncSCLV (zip val arr)
        , someFuncSCLV (zip val arr) == V.toList vec)
  defaultMain
        [ bench "someFuncVector" (whnf someFuncVector (arrVector, valVector))
        , bench "someFuncSCLV"   (nf someFuncSCLV (zip val arr))
        , bench "someFuncKTM"    (nf someFuncKTM (zip val arr))
        , bench "original"       (nf someFuncOrig (zip val arr))
        ]

Maybe my benchmark is flawed, but the results are rather disappointing:
Vector: 379.0164 ms (low density - what the heck?)
SCLV: 207.5399 ms
Kenny: 200.6028 ms
Original: 138.4270 ms
[tommd@Mavlo Test]$ ./t
(True,True,True)
warming up
estimating clock resolution...
mean is 13.65277 us (40001 iterations)
found 3378 outliers among 39999 samples (8.4%)
  1272 (3.2%) high mild
  2106 (5.3%) high severe
estimating cost of a clock call...
mean is 1.653858 us (58 iterations)
found 3 outliers among 58 samples (5.2%)
  2 (3.4%) high mild
  1 (1.7%) high severe

benchmarking someFuncVector
collecting 100 samples, 1 iterations each, in estimated 54.56119 s
bootstrapping with 100000 resamples
mean: 379.0164 ms, lb 357.0403 ms, ub 401.0113 ms, ci 0.950
std dev: 112.6714 ms, lb 101.8206 ms, ub 125.4846 ms, ci 0.950
variance introduced by outliers: 4.000%
variance is slightly inflated by outliers

benchmarking someFuncSCLV
collecting 100 samples, 1 iterations each, in estimated 20.92559 s
bootstrapping with 100000 resamples
mean: 207.5399 ms, lb 207.4099 ms, ub 207.8410 ms, ci 0.950
std dev: 955.1629 us, lb 507.1857 us, ub 1.937356 ms, ci 0.950
found 3 outliers among 100 samples (3.0%)
  2 (2.0%) high severe
variance introduced by outliers: 0.990%
variance is unaffected by outliers

benchmarking someFuncKTM
collecting 100 samples, 1 iterations each, in estimated 20.14799 s
bootstrapping with 100000 resamples
mean: 200.6028 ms, lb 200.5273 ms, ub 200.6994 ms, ci 0.950
std dev: 434.9564 us, lb 347.5326 us, ub 672.6736 us, ci 0.950
found 1 outliers among 100 samples (1.0%)
  1 (1.0%) high severe
variance introduced by outliers: 0.990%
variance is unaffected by outliers

benchmarking original
collecting 100 samples, 1 iterations each, in estimated 14.05241 s
bootstrapping with 100000 resamples
mean: 138.4270 ms, lb 138.2244 ms, ub 138.6568 ms, ci 0.950
std dev: 1.107366 ms, lb 930.6549 us, ub 1.381234 ms, ci 0.950
found 15 outliers among 100 samples (15.0%)
  7 (7.0%) low mild
  7 (7.0%) high mild
  1 (1.0%) high severe
variance introduced by outliers: 0.990%
variance is unaffected by outliers

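Switching to Vector's stream-fusion version speeds up my code enormously on this benchmark (for me, from 476.9359 ms down to 73.31412 μs (!)). Just use import qualified Data.Vector.Fusion.Stream as V and replace V.Vector with V.Stream. - Travis Brown
Travis: I doubt the validity of your test. You probably had it evaluate to whnf, so it didn't actually do any work. Try nf (V.toList . someFuncVector) and wait long enough, and you'll see: collecting 100 samples, 1 iterations each, in estimated 4284.720 s. In other words, extrapolating (because I don't want to wait over an hour), the Stream solution takes about 42 seconds for this problem, not 73 μs. - Thomas M. DuBuisson
You're right, sorry about that. I should have known it was too good to be true. - Travis Brown
You're right, your cleaned-up version is clearly the best. Not sure why, but I suspect the compiler just can't perform some of the fusions you did by hand. - sclv
Pushing the addition inside the max, in particular, seems to really pay off. By the way, raising the unfolding threshold helped, but it helped the original as well. - sclv
sclv: Agreed. I found the preprocessing step (\(c,xs) -> map (+c) xs) to be really expensive, and it doesn't seem to fuse. I even tried importing Data.List.Stream for better fusion; it helped, but not dramatically. - Thomas M. DuBuisson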
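The whnf versus nf point above can be seen without criterion at all. criterion's whnf forces a result only to its outermost constructor, while nf forces it completely, much like evaluate versus evaluate . force from the deepseq package that ships with GHC. This small stand-alone illustration (not a benchmark; lazyResult and raises are hypothetical names) shows that reaching WHNF of a lazily built structure need not compute any of its elements:

```haskell
import Control.DeepSeq (force)
import Control.Exception (ErrorCall, evaluate, try)

-- A result whose spine is cheap but whose second element blows up if forced.
lazyResult :: [Int]
lazyResult = [1, error "second element was forced"]

-- True if running the given forcing strategy on x raised an ErrorCall.
raises :: (a -> IO a) -> a -> IO Bool
raises strategy x = do
  r <- try (strategy x >> return ()) :: IO (Either ErrorCall ())
  return (either (const True) (const False) r)

main :: IO ()
main = do
  atWhnf <- raises evaluate lazyResult            -- outermost (:) only, like whnf
  atNf   <- raises (evaluate . force) lazyResult  -- every element, like nf
  print (atWhnf, atNf)
```

Running it prints (False,True): forcing to WHNF never touches the failing element, while forcing to normal form does.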


I haven't checked the efficiency, but how about this:

map maximum $ transpose [ map (a+) bs | (a,bs) <- arr]

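? Since the result is about the sums, the values are added to the lists first. Then we transpose the list of lists so that it is in column-major form. Finally, we take the maximum of each column. (You need to import Data.List, by the way.)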
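The three steps can be traced on the question's example data. This is a sketch that splits the one-liner into named parts (shifted and example are names introduced here for illustration) and prints each intermediate value:

```haskell
import Data.List (transpose)

-- Step 1: add each row's fixed value to every element of that row.
shifted :: [(Int, [Int])] -> [[Int]]
shifted arr = [map (a +) bs | (a, bs) <- arr]

-- Steps 2 and 3: turn rows into columns, then take each column's maximum.
bestKTM :: [(Int, [Int])] -> [Int]
bestKTM = map maximum . transpose . shifted

example :: [(Int, [Int])]
example = [(2, [10, 5, 12]), (1, [2, 8, 20]), (4, [3, 2, 10])]

main :: IO ()
main = do
  print (shifted example)              -- [[12,7,14],[3,9,21],[7,6,14]]
  print (transpose (shifted example))  -- [[12,3,7],[7,9,6],[14,21,14]]
  print (bestKTM example)              -- [12,9,21]
```

Because the first column already needs the head of every row, transpose likely keeps all the (shifted) rows alive until the last column is done, which may explain the memory use reported in the comment below.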


I tried it this way; unfortunately it didn't help :( It just consumes a lot of memory. - ondra

You could try Data.Vector:
import qualified Data.Vector as V

best :: V.Vector (V.Vector Int) -> V.Vector Int -> V.Vector Int
best = (V.foldl1' (V.zipWith max) .) . (V.zipWith . flip $ V.map . (+))

convert :: [[Int]] -> V.Vector (V.Vector Int)
convert = V.fromList . map V.fromList

arr = convert [[10, 5, 12], [2, 8, 20], [3, 2, 10]]
val = V.fromList [2, 1, 4] :: V.Vector Int

This works:

*Main> best arr val
fromList [12,9,21] :: Data.Vector.Vector
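The point-free definition of best is compact but hard to read. Unfolding it by hand gives the pointful sketch below (bestPointful is a hypothetical name); it should compute the same result as best, since V.zipWith . flip $ V.map . (+) is V.zipWith (\row v -> V.map (v +) row):

```haskell
import qualified Data.Vector as V

-- Pointful reading of the point-free 'best':
--   V.zipWith (\row v -> V.map (v +) row) adds each fixed value to its row,
--   V.foldl1' (V.zipWith max) then combines the shifted rows column by column.
bestPointful :: V.Vector (V.Vector Int) -> V.Vector Int -> V.Vector Int
bestPointful rows vals =
  V.foldl1' (V.zipWith max) (V.zipWith (\row v -> V.map (v +) row) rows vals)

main :: IO ()
main = do
  let rows = V.fromList (map V.fromList [[10, 5, 12], [2, 8, 20], [3, 2, 10]])
      vals = V.fromList [2, 1, 4]
  print (V.toList (bestPointful rows vals))
```

This prints [12,9,21]. Like the original, it relies on V.foldl1', so it fails on an empty outer vector instead of returning an empty result.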

Yes, see the comments under my answer; there is some dispute about Stream performance for this problem. - Thomas M. DuBuisson

best = foldl' (zipWith max) (repeat 0) . map (\(fv,xs) -> map (+fv) xs)

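Like Kenny, we add first. Like you, we make a single pass, but with zipWith max we can do it more generally and concisely. I haven't benchmarked it rigorously, but it should be pretty good.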


Like the vector answer above, you can use foldl1' and drop the starting value repeat 0: best = foldl1' (zipWith max) . map (\(fv,xs) -> map (+fv) xs) - Travis Brown
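The two seeds are not interchangeable in every case. Seeding with repeat 0 (and likewise the question's replicate 3 0) silently assumes that every sum is non-negative, whereas seeding with the first row does not. On empty input the zero-seeded version returns an infinite list of zeros and the foldl1' version raises an error. A sketch comparing both on hypothetical all-negative data (bestZero, bestFirst and negatives are names introduced here):

```haskell
import Data.List (foldl', foldl1')

-- sclv's version: the accumulator starts as an infinite list of zeros.
bestZero :: [(Int, [Int])] -> [Int]
bestZero = foldl' (zipWith max) (repeat 0) . map (\(fv, xs) -> map (+ fv) xs)

-- Travis Brown's variant: seed the fold with the first shifted row instead.
bestFirst :: [(Int, [Int])] -> [Int]
bestFirst = foldl1' (zipWith max) . map (\(fv, xs) -> map (+ fv) xs)

-- Hypothetical input where every sum is negative.
negatives :: [(Int, [Int])]
negatives = [(-5, [1, 2]), (-7, [3, 0])]

main :: IO ()
main = do
  print (bestZero negatives)   -- [0,0]: the zero seed beats every negative sum
  print (bestFirst negatives)  -- [-4,-3]: the true column maxima
```

On the question's own data both functions return [12,9,21], since all of those sums are positive.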
