在Haskell中的缓存和显式并行性

6

我目前正在尝试优化我在Projet Euler上解决问题14的解决方案。我非常喜欢Haskell,并认为它非常适合这些问题,以下是我尝试过的三种不同的解决方案:

import Data.List (unfoldr, maximumBy)
import Data.Maybe (fromJust, isNothing)
import Data.Ord (comparing)
import Control.Parallel

next :: Integer -> Maybe (Integer)
next 1 = Nothing
next n
  | even n = Just (div n 2)
  | odd n  = Just (3 * n + 1)

get_sequence :: Integer -> [Integer]
get_sequence n = n : unfoldr (pack . next) n
  where pack n = if isNothing n then Nothing else Just (fromJust n, fromJust n)

get_sequence_length :: Integer -> Integer
get_sequence_length n
    | isNothing (next n) = 1
    | otherwise = 1 + (get_sequence_length $ fromJust (next n))

-- 8 seconds
main1 = print $ maximumBy (comparing length) $ map get_sequence [1..1000000]

-- 5 seconds
main2 = print $ maximum $ map (\n -> (get_sequence_length n, n)) [1..1000000]

-- Never finishes
main3 = print solution
  where
    s1 = maximumBy (comparing length) $ map get_sequence [1..500000]
    s2 = maximumBy (comparing length) $ map get_sequence [500001..10000000]
    solution = (s1 `par` s2) `pseq` max s1 s2

现在,如果您看一下实际问题,就会发现有很多缓存的潜力,因为大多数新序列将包含先前已计算过的子序列。
相比之下,我也用C写了一个版本:
使用缓存的运行时间:0.03秒
不使用缓存的运行时间:0.3秒 这简直太疯狂了!当然,缓存将时间减少了10倍,但即使没有缓存,它仍然比我的Haskell代码快至少17倍。
我的代码有什么问题? 为什么Haskell不为我缓存函数调用?由于函数是纯的,缓存不应该很简单吗,只是一个可用内存的问题?
我的第三个并行版本有什么问题?为什么它无法完成?
关于Haskell作为一种语言,编译器是否自动并行化某些代码(折叠、映射等),还是必须始终明确地使用Control.Parallel来完成?
编辑:我偶然发现了this类似的问题。他们提到他的函数不是尾递归的。我的get_sequence_length是尾递归的吗?如果不是,我该如何使其成为尾递归? 编辑2:
致Daniel:
非常感谢您的回复,真的很棒。我尝试了您的改进,发现有一些严重问题。
我使用Windows 7 (64位),3.3 GHZ四核处理器和8GB RAM测试。第一件事就是按照你说的将所有的Integer替换为Int,但每当我运行任何主程序时都会用光内存,即使+RTS kSize -RTS设置的非常高。
最终我找到了this(stackoverflow太神奇了……), 这意味着由于Windows上的所有Haskell程序都是以32位运行的,因此Ints会溢出从而导致无限递归,简直不可思议……
我在Linux虚拟机中运行了测试(具有64位ghc),结果相似。

你的 main3 中多了一个零... - Daniel Wagner
2个回答

20

好的,让我们从头开始。 第一件重要的事情是给出您用于编译和运行的确切命令行; 对于我的回答,我将使用此行针对所有程序的时序:

ghc -O2 -threaded -rtsopts test && time ./test +RTS -N

接下来:由于不同机器的计时时间差异很大,我们将为我的机器和你的程序提供一些基准计时。以下是我计算机上uname -a的输出:

Linux sorghum 3.4.4-2-ARCH #1 SMP PREEMPT Sun Jun 24 18:59:47 CEST 2012 x86_64 Intel(R) Core(TM)2 Quad CPU Q6600 @ 2.40GHz GenuineIntel GNU/Linux

重点是:四核、2.4GHz、64位。

使用 main1: 30.42s 用户 2.61s 系统 149% CPU 22.025 总计
使用 main2: 21.42s 用户 1.18s 系统 129% CPU 17.416 总计
使用 main3: 22.71s 用户 2.02s 系统 220% CPU 11.237 总计

实际上,我通过两种方式修改了 main3 :首先,在 s2 的范围末尾删除了一个零,其次将 max s1 s2 改为 maximumBy (comparing length) [s1, s2],因为前者仅偶然计算出正确的答案。 =)

现在我将专注于串行速度。(回答你的直接问题之一:不,GHC 不会自动并行化或记忆化你的程序。这两件事都有很难估算的开销,因此很难确定何时执行它们会有益。我不知道为什么即使在这个答案中的串行解决方案也能获得 >100% 的 CPU 利用率;也许在另一个线程中正在发生一些垃圾回收之类的事情。)我们将从 main2 开始,因为它是两个串行实现中更快的一个。获得一点提升的最便宜的方法是将所有类型签名从 Integer 更改为 Int

使用 Int: 11.17s 用户 0.50s 系统 129% CPU 8.986 总计(大约快了两倍)

下一次提升来自于减少内部循环中的分配(消除中间的 Maybe 值)。

import Data.List
import Data.Ord

get_sequence_length :: Int -> Int
get_sequence_length 1 = 1
get_sequence_length n
    | even n = 1 + get_sequence_length (n `div` 2)
    | odd  n = 1 + get_sequence_length (3 * n + 1)

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

使用以下命令:4.84s user 0.03s system 101% cpu 4.777 total

接下来的提升来自于使用比evendiv更快的操作:

import Data.Bits
import Data.List
import Data.Ord

even' n = n .&. 1 == 0

get_sequence_length :: Int -> Int
get_sequence_length 1 = 1
get_sequence_length n = 1 + get_sequence_length next where
    next = if even' n then n `quot` 2 else 3 * n + 1

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

使用以下命令:1.27s user 0.03s system 105% cpu 1.232 total

对于在家跟进的人,这比我们开始使用的main2快了约17倍——与切换到C相当竞争。

对于记忆化,有一些选择。最简单的方法是使用现有包,如data-memocombinators创建一个不可变数组并从中读取。选择此数组的大小对计时非常敏感;对于这个问题,我发现50000是一个相当好的上限。

import Data.Bits
import Data.MemoCombinators
import Data.List
import Data.Ord

even' n = n .&. 1 == 0

pre_length :: (Int -> Int) -> (Int -> Int)
pre_length f 1 = 1
pre_length f n = 1 + f next where
    next = if even' n then n `quot` 2 else 3 * n + 1

get_sequence_length :: Int -> Int
get_sequence_length = arrayRange (1,50000) (pre_length get_sequence_length)

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

使用这个:0.53秒用户0.10秒系统149% CPU 0.421总计

最快的方法是为备忘录位使用可变的未装箱数组。这不太惯用,但速度非常快。速度对该数组的大小不太敏感,只要该数组的大小与您要获取答案的最大事物的大小大致相同即可。

import Control.Monad
import Control.Monad.ST
import Data.Array.Base
import Data.Array.ST
import Data.Bits
import Data.List
import Data.Ord

even' n = n .&. 1 == 0
next  n = if even' n then n `quot` 2 else 3 * n + 1

get_sequence_length :: STUArray s Int Int -> Int -> ST s Int
get_sequence_length arr n = do
    bounds@(lo,hi) <- getBounds arr
    if not (inRange bounds n) then (+1) `fmap` get_sequence_length arr (next n) else do
        let ix = n-lo
        v <- unsafeRead arr ix
        if v > 0 then return v else do
            v' <- get_sequence_length arr (next n)
            unsafeWrite arr ix (v'+1)
            return (v'+1)

maxLength :: (Int,Int)
maxLength = runST $ do
    arr <- newArray (1,1000000) 0
    writeArray arr 1 1
    loop arr 1 1 1000000
    where
    loop arr n len 1  = return (n,len)
    loop arr n len n' = do
        len' <- get_sequence_length arr n'
        if len' > len then loop arr n' len' (n'-1) else loop arr n len (n'-1)

main = print maxLength

使用以下代码:0.16s user 0.02s system 138% cpu 0.130 total(与记忆化 C 版本相当)


进展很好,最终结果也很不错。整个优化顺序在这一点上感觉已经被编码了。编辑:一个问题,为什么你使用Array而不是Vector?这是个人偏好,但我就是无法忍受Array的接口。 - Thomas M. DuBuisson
非常感谢,你的回答真的很直接明了。然而我不明白的是,你的第一个代码示例是如何消除子列表的。长度函数难道不只是顺序运行get_sequence_length吗?我不明白它与原始main2有什么不同,除了部分代码被拆分到长度函数中。(另外,请查看我的编辑以获取更长的回复) - user1599468
@user1599468 哎呀,32位的问题有点烦人。至于消除列表 - 你是对的,我没有说得很准确。我会尽快更新我的答案,但简短的回答是,在每次循环迭代期间,它消除了两个JustNothing值的分配。 - Daniel Wagner
1
如果你想追求原始速度,你应该用 n `shiftR` 1 替换 n `quot` 2。在我的电脑上,这样会快很多。此外,数组版本(我试过的唯一一个)比非线程版本更快。最后,通过避免使用 getBoundsinRangelet ix = n - lo,从索引0开始数组,并将上限作为参数传递给 get_sequence_length,这样你就可以直接在 n 处比较 hiunsafeRead/Write - Daniel Fischer

0

GHC不会自动并行化任何内容。正如你猜测的那样,get_sequence_length不是尾递归的。请参见这里。考虑编译器(除非它为您执行一些好的优化)无法评估所有这些递归加法直到您达到末尾; 您正在“构建thunks”,这通常不是一个好事情。

相反,尝试调用递归辅助函数并传递累加器,或尝试根据foldr定义它。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接