如何优化这段 Haskell 代码以在次线性时间内汇总素数?

18

欧拉计划的第10个问题是找出小于给定n的所有质数的总和。

我通过使用埃拉托色尼筛法生成的质数之和来解决它。然后,我发现了Lucy_Hedgehog的更高效的解决方案(次线性!)。

对于n = 2⋅10^9

  • Python代码(来自上面的引用)在Python 2.7.3中运行需要1.2秒。

  • C++代码(我的)在使用g++ 4.8.4编译后大约需要0.3秒。

我重新使用Haskell实现了相同的算法,因为我正在学习它:

import Data.List

import Data.Map (Map, (!))
import qualified Data.Map as Map

problem10 :: Integer -> Integer
problem10 n = (sieve (Map.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
              where vs = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
                    r  = floor (sqrt (fromIntegral n))

sieve :: Map Integer Integer -> Integer -> Integer -> [Integer] -> Map Integer Integer
sieve m p r vs | p > r     = m
               | otherwise = sieve (if m ! p > m ! (p - 1) then update m vs p else m) (p + 1) r vs

update :: Map Integer Integer -> [Integer] -> Integer -> Map Integer Integer
update m vs p = foldl' decrease m (map (\v -> (v, sumOfSieved m v p)) (takeWhile (>= p*p) vs))

decrease :: Map Integer Integer -> (Integer, Integer) -> Map Integer Integer
decrease m (k, v) = Map.insertWith (flip (-)) k v m

sumOfSieved :: Map Integer Integer -> Integer -> Integer -> Integer
sumOfSieved m v p = p * (m ! (v `div` p) - m ! (p - 1))

main = print $ problem10 $ 2*10^9

我使用ghc -O2 10.hs编译并使用time ./10运行,它给出了正确的答案,但需要大约7秒钟。 我使用ghc -prof -fprof-auto -rtsopts 10编译并使用./10 +RTS -p -h运行。 10.prof显示decrease占用了52.2%的时间和67.5%的内存分配。 运行hp2ps 10.hp后,我得到了这样的堆配置文件:

hp

看起来 decrease 占用了大部分堆空间。 GHC 版本为 7.6.3。

你如何优化这段 Haskell 代码的运行时间?


更新于 13.06.17:

尝试hashtables 包中的可变 Data.HashTable.IO.BasicHashTable 替换不可变的 Data.Map,但可能我做错了什么,因为对于很小的 n = 30,它已经花费了大约10秒钟。有什么问题吗?

更新于 18.06.17:

对哈希表性能问题感到好奇是一篇不错的文章。我使用可变的Data.HashTable.ST.Linear采用了Sherh的代码, 但是换成了Data.Judy。它运行时间为1.1秒,仍然相对较慢。


1
请问您能否将算法拆分成合理大小的顶层函数,以便进行推理(查看类型等有助于...)?谢谢。 - Centril
1
将“Map”更改为“IntMap”。 - Thomas M. DuBuisson
1
@Centril 我将sieveupdatedecreasesumOfsieved分离出来了,希望这有所帮助。 - Adam Stelmaszczyk
1
你的算法有多少是通过修改Maps来支配的,又有多少是通过读取来支配的?如果你需要大量读取,那么使用http://hackage.haskell.org/package/vector-0.12.0.1/docs/Data-Vector-Unboxed.html对于数值代码非常好。还有http://hackage.haskell.org/package/array。 - Centril
1
@ThomasM.DuBuisson 谢谢,我也把Integer改成了Int,代码在这里。这将运行时间缩短到了5秒。 - Adam Stelmaszczyk
显示剩余3条评论
4个回答

7
我做了一些小改进,现在它在我的机器上运行时间为3.4-3.5秒。使用IntMap.Strict有很大帮助。除此之外,我只是手动执行了一些ghc优化,以确保将Haskell代码更接近于您链接中的Python代码。作为下一步,您可以尝试使用一些可变的HashMap。但我不确定...IntMap不能比一些可变容器更快,因为它是不可变的。尽管我仍然对它的效率感到惊讶。我希望这可以更快地实现。
以下是代码:
import Data.List (foldl')
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap

p :: Int -> Int
p n = (sieve (IntMap.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
               where vs = [n `div` i | i <- [1..r]] ++ [n', n' - 1 .. 1]
                     r  = floor (sqrt (fromIntegral n) :: Double)
                     n' = n `div` r - 1

sieve :: IntMap Int -> Int -> Int -> [Int] -> IntMap Int
sieve m' p' r vs = go m' p'
  where
    go m p | p > r               = m
           | m ! p > m ! (p - 1) = go (update m vs p) (p + 1)
           | otherwise           = go m (p + 1)

update :: IntMap Int -> [Int] -> Int -> IntMap Int
update s vs p = foldl' decrease s (takeWhile (>= p2) vs)
  where
    sp = s ! (p - 1)
    p2 = p * p
    sumOfSieved v = p * (s ! (v `div` p) - sp)
    decrease m  v = IntMap.adjust (subtract $ sumOfSieved v) v m

main :: IO ()
main = print $ p $ 2*10^(9 :: Int) 

更新:

使用可变的hashtables,我已经成功将Haskell的性能提高到了约5.5秒,具体实现请见this implementation

此外,我在多个地方使用了非装箱向量而不是列表。线性哈希似乎是最快的。我认为这可以做得更快。我注意到hasthables包中有一个sse42选项,但不确定是否设置正确,即使没有它也能运行得很快。

更新2(2017年6月19日)

我已经成功地使它比@Krom的最佳解决方案(使用我的代码+他的映射)快了3倍,通过完全放弃Judy哈希映射。相反,只使用普通数组。如果您注意到S哈希映射的键要么是从1到n'的序列,要么是i从1到r的n div i,则可以想出相同的想法。因此,我们可以将这样的哈希映射表示为两个数组,在数组中进行查找取决于搜索键。
$ time ./judy
95673602693282040

real    0m0.590s
user    0m0.588s
sys     0m0.000s

我的代码+我的稀疏地图

$ time ./sparse
95673602693282040

real    0m0.203s
user    0m0.196s
sys     0m0.004s

如果使用已生成的向量和Vector库,以及将readArray替换为unsafeRead,甚至可以更快地完成此操作。但我认为只有在您真正希望尽可能优化时才应该这样做。

与这个解决方案进行比较是欺骗行为,是不公平的。我预计在Python和C++中实现相同的想法会更快。但是@Krom使用了闭合哈希映射的解决方案已经是欺骗了,因为它使用了自定义数据结构而不是标准数据结构。至少可以看出,在Haskell中,标准和最流行的哈希映射并不那么快。对于这种问题,使用更好的算法和更好的特定数据结构可能会更好。

这里是结果代码。


1
@AdamStelmaszczyk 你可以尝试阅读这篇博客文章。它可能会帮助你更快地提出命令式/可变的解决方案:https://www.reddit.com/r/haskell/comments/6e4wq8/imperative_haskell/ - Shersh
1
@AdamStelmaszczyk 我已经更新了答案,使用可变版本运行速度非常快。直接从您提供的链接中翻译了Python代码。 - Shersh
1
如果我尝试你的更新程序,它给出了错误的结果318504960 - typetetris
1
@Shersh,我赞赏你聪明的解决方案,并点赞了你的回答! - typetetris
1
我计时了你的稀疏映射 - 0.2秒,这是给出的最快的Haskell代码,做得好。其他人也非常有帮助,谢谢,但我无法分割赏金,所以我认为你应该收到它。出于好奇,我用C++ 实现 了你的想法,我只是将所有东西存储在一个数组S中,它运行时间为0.15秒。 - Adam Stelmaszczyk
显示剩余15条评论

4

首先作为基准,现有方法在我的机器上的时间:

  1. Original program posted in the question:

    time stack exec primorig
    95673602693282040
    
    real    0m4.601s
    user    0m4.387s
    sys     0m0.251s
    
  2. Second the version using Data.IntMap.Strict from here

    time stack exec primIntMapStrict
    95673602693282040
    
    real    0m2.775s
    user    0m2.753s
    sys     0m0.052s
    
  3. Shershs code with Data.Judy dropped in here

    time stack exec prim-hash2
    95673602693282040
    
    real    0m0.945s
    user    0m0.955s
    sys     0m0.028s
    
  4. Your python solution.

    I compiled it with

    python -O -m py_compile problem10.py
    

    and the timing:

    time python __pycache__/problem10.cpython-36.opt-1.pyc
    95673602693282040
    
    real    0m1.163s
    user    0m1.160s
    sys     0m0.003s
    
  5. Your C++ version:

    $ g++ -O2 --std=c++11 p10.cpp -o p10
    $ time ./p10
    sum(2000000000) = 95673602693282040
    
    real    0m0.314s
    user    0m0.310s
    sys     0m0.003s
    
我没有为slow.hs提供基准,因为我不想等待使用2*10^9作为参数运行时完成。
次秒级性能
下面的程序在我的机器上不到一秒钟就可以运行。
它使用手动编写的哈希映射表,采用闭散列和线性探测,并使用某种变体的Knuth哈希函数,请参见这里
当然,它在某种程度上是针对特定情况进行了优化,例如查找函数期望搜索键存在。
计时:
time stack exec prim
95673602693282040

real    0m0.725s
user    0m0.714s
sys     0m0.047s

首先,我实现了自己的哈希映射表,仅仅是用来散列键值对。

key `mod` size

我选择了一个比预期输入大多倍的尺寸,但程序需要22秒或更长时间才能完成。

最后,关键是选择适合工作负载的哈希函数。

以下是程序:

import Data.Maybe
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead)

type Number = Int

data Map = Map { keys :: IOUArray Int Number
               , values :: IOUArray Int Number
               , size :: !Int 
               , factor :: !Int
               }

newMap :: Int -> Int -> IO Map
newMap s f = do
  k <- newArray (0, s-1) 0
  v <- newArray (0, s-1) 0
  return $ Map k v s f 

storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
storeKey arr s f key = go ((key * f) `mod` s)
  where
    go :: Int -> IO Int
    go ind = do
      v <- readArray arr ind
      go2 v ind
    go2 v ind
      | v == 0    = do { writeArray arr ind key; return ind; }
      | v == key  = return ind
      | otherwise = go ((ind + 1) `mod` s)

loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
  where
    go :: Int -> IO Int
    go ix = do
      v <- unsafeRead arr ix
      if v == key then return ix else go ((ix + 1) `mod` s)

insertIntoMap :: Map -> (Number, Number) -> IO Map
insertIntoMap m@(Map ks vs s f) (k, v) = do
  ix <- storeKey ks s f k
  writeArray vs ix v
  return m

fromList :: Int -> Int -> [(Number, Number)] -> IO Map
fromList s f xs = do
  m <- newMap s f
  foldM insertIntoMap m xs

(!) :: Map -> Number -> IO Number
(!) (Map ks vs s f) k = do
  ix <- loadKey ks s f k
  readArray vs ix

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate (Map ks vs s fac) i f = do
  ix <- loadKey ks s fac i
  old <- readArray vs ix
  let x' = f old
  x' `seq` writeArray vs ix x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

vss' n r = r + n `div` r -1

list' :: Int -> Int -> [Number] -> IO Map
list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' (19*vss) (19*vss+7) vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          vss = vss' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

我对哈希和相关技术不是专业的,所以这个可以被大大改进。也许我们Haskellers应该改进现成的哈希映射或提供一些更简单的。

我的哈希映射,Shershs代码

如果我把我的哈希映射插入到Shershs(见下面的答案)代码中,请参见此处,我们甚至降到了

time stack exec prim-hash2
95673602693282040

real    0m0.601s
user    0m0.604s
sys     0m0.034s

为什么slow.hs很慢?

如果你阅读Data.HashTable.ST.Basicinsert函数的源代码,你会发现它删除了旧的键值对并插入了一个新的键值对。它没有查找值的“位置”并进行更改,正如人们可能想象的那样,如果他们读到它是一个“可变”哈希表。在这里,哈希表本身是可变的,因此您不需要复制整个哈希表以插入新的键值对,但是对于键值对的值位置则不然。我不知道这就是slow.hs缓慢的整个故事,但我的猜测是,这是其中相当大的一部分。

一些小改进

这就是我第一次尝试改进程序时遵循的思路。

看,你不需要从键到值的可变映射。你的键集已经固定了。你需要一个从键到可变位置的映射。 (顺便说一下,默认情况下C++就是这样做的。)

因此,我尝试想出一个这样的解决方案。我首先使用了Data.IntMap.StrictData.IORef中的IntMap IORef,其时间为

tack exec prim
95673602693282040

real    0m2.134s
user    0m2.141s
sys     0m0.028s

我认为使用非装箱值可能会有所帮助。为此,我使用每个仅包含1个元素的IOUArray Int Int代替IORef进行测试,并得出以下结果:

time stack exec prim
95673602693282040

real    0m2.015s
user    0m2.018s
sys     0m0.038s

这两种方法并没有太大区别,所以我尝试使用unsafeReadunsafeWrite来消除1元素数组中的边界检查,并获得了以下时间:

time stack exec prim
95673602693282040

real    0m1.845s
user    0m1.850s
sys     0m0.030s

我使用 Data.IntMap.Strict 得到了最佳结果。

当然,我运行了每个程序多次以查看时间是否稳定,并且运行时间的差异不只是噪声。

看起来这些都只是微小的优化。

以下是在不使用手动构建的数据结构时,对我来说运行最快的程序:

import qualified Data.IntMap.Strict as M
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead, unsafeWrite)

type Number = Int
type Place = IOUArray Number Number
type Map = M.IntMap Place

tupleToRef :: (Number, Number) -> IO (Number, Place)
tupleToRef = traverse (newArray (0,0))

insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
insertRefs = traverse tupleToRef

fromList :: [(Number, Number)] -> IO Map 
fromList xs = M.fromList <$> insertRefs xs

(!) :: Map -> Number -> IO Number
(!) m i = unsafeRead (m M.! i) 0

mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate m i f = do
  let place = m M.! i
  old <- unsafeRead place 0
  let x' = f old
  -- make the application of f strict
  x' `seq` unsafeWrite place 0 x'

r' :: Number -> Number
r'  = floor . sqrt . fromIntegral

vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]  

list' :: [Number] -> IO Map
list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]

problem10 :: Number -> IO Number
problem10 n = do
      m <- list' vs
      nm <- sieve m 2 r vs
      nm ! n
    where vs = vs' n r
          r  = r' n

sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r     = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs

update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs

decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m

sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)

main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9

如果你分析一下,会发现大部分时间都花在了自定义查找函数(!)上,不知道如何进一步改进。尝试使用{-# INLINE (!) #-}内联(!)并没有取得更好的结果;也许ghc已经这样做了。

2
下班后,我会在 Reddit 上发布请求帮助的帖子。如果没有现成的库能够击败我这里的天真方法,那就有点疯狂了。 - typetetris

4

我的这段代码可以在0.3秒内评估出2×10^9的和,并且可以在19.6秒内(如果有足够的RAM)计算出10^12(18435588552550705911377)的和。

import Control.DeepSeq 
import qualified Control.Monad as ControlMonad
import qualified Data.Array as Array
import qualified Data.Array.ST as ArrayST
import qualified Data.Array.Base as ArrayBase

primeLucy :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> (Integer->Integer)
primeLucy f sf n = g
  where
    r = fromIntegral $ integerSquareRoot n
    ni = fromIntegral n
    loop from to c = let go i = ControlMonad.when (to<=i) (c i >> go (i-1)) in go from

    k = ArrayST.runSTArray $ do
      k <- ArrayST.newListArray (-r,r) $ force $
        [sf (div n (toInteger i)) - sf 1|i<-[r,r-1..1]] ++
        [0] ++
        [sf (toInteger i) - sf 1|i<-[1..r]]
      ControlMonad.forM_ (takeWhile (<=r) primes) $ \p -> do
        l <- ArrayST.readArray k (p-1)
        let q = force $ f (toInteger p)

        let adjust = \i j -> do { v <- ArrayBase.unsafeRead k (i+r); w <- ArrayBase.unsafeRead k (j+r); ArrayBase.unsafeWrite k (i+r) $!! v+q*(l-w) }

        loop (-1)         (-div r p)              $ \i -> adjust i (i*p)
        loop (-div r p-1) (-min r (div ni (p*p))) $ \i -> adjust i (div (-ni) (i*p))
        loop r            (p*p)                   $ \i -> adjust i (div i p)

      return k

    g :: Integer -> Integer
    g m
      | m >= 1 && m <= integerSquareRoot n                       = k Array.! (fromIntegral m)
      | m >= integerSquareRoot n && m <= n && div n (div n m)==m = k Array.! (fromIntegral (negate (div n m)))
      | otherwise = error $ "Function not precalculated for value " ++ show m

primeSum :: Integer -> Integer
primeSum n = (primeLucy id (\m -> div (m*m+m) 2) n) n

如果您的integerSquareRoot函数存在问题(据报道确实存在一些问题),您可以在此处使用floor . sqrt . fromIntegral进行替换。
解释:
正如其名称所暗示的那样,它基于"Lucy Hedgehog"发现的原始作者最终发现的著名方法的概括。
它允许您计算许多形式为sum(p为质数)的总和,而无需枚举所有小于N的质数,并在O(N^0.75)的时间内完成。
它的输入是函数f(即id,如果您想要质数之和),其对于所有整数的总和函数(在这种情况下为前m个整数的和或div (m*m+m) 2)和N。 PrimeLucy返回一个查找函数eq(p为质数),限制为某些n值:values

抱歉只是从我的标准库复制,并没有包含所有的导入。force($!!)来自于Control.DeepSeq,该模块已包含在大多数发行版中。primes只是质数列表,来源随意。我建议使用Hackage上的arithmoi包。 - CarlEdman
@CarlEdman 我已经填写了空白并使用n = 2⋅10^9测试了您的代码,运行时间为0.3秒,令人印象深刻。请问您能否描述一下您的方法? - Adam Stelmaszczyk
我觉得我在之前的问题中没有表达清楚,抱歉。我理解Lucy的算法是如何工作的。我想要了解的是为什么你的Haskell代码运行速度这么快,它是否采用了Shersh的“我的代码+我的映射”方法,以达到0.2秒的速度?例如,“我使用了2个可变数组,在第一个数组中存储x...这与Shersh相同/不同的地方在于y的差异”。 - Adam Stelmaszczyk
1
除了通常的试错手动调整之外,性能背后没有什么魔法,算法只是一个开始。它可能始于像Shersh这样的东西,然后逐位转换。有助于的事情包括:使用单个可变数组;通过按照谨慎的顺序执行更新,而不是复制到新数组上,在每次迭代中仅修改数组;使用unsafeRead/Write;自定义内部循环而不是forM_;通过'$!!'和'force'尽早强制结果。 - CarlEdman
@AdamStelmaszczyk,还要感谢您修复我的TeX!我总是忘记哪个论坛允许哪种语法。 - CarlEdman
显示剩余2条评论

2

试一下这个,然后告诉我速度有多快:

-- sum of primes

import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed

sieve :: Int -> UArray Int Bool
sieve n = runSTUArray $ do
    let m = (n-1) `div` 2
        r = floor . sqrt $ fromIntegral n
    bits <- newArray (0, m-1) True
    forM_ [0 .. r `div` 2 - 1] $ \i -> do
        isPrime <- readArray bits i
        when isPrime $ do
            let a = 2*i*i + 6*i + 3
                b = 2*i*i + 8*i + 6
            forM_ [a, b .. (m-1)] $ \j -> do
                writeArray bits j False
    return bits

primes :: Int -> [Int]
primes n = 2 : [2*i+3 | (i, True) <- assocs $ sieve n]

main = do
    print $ sum $ primes 1000000

你可以在 ideone 上运行它。我的算法是埃拉托色尼筛法,对于小的n来说应该非常快。当n = 2,000,000,000时,数组大小可能会成为一个问题,在这种情况下,您需要使用分段筛法。有关埃拉托色尼筛法的更多信息,请参见我的博客。有关分段筛法的信息,请参见此答案(不幸的是不是使用Haskell)。

我尝试了n = 2⋅10^9的值来与其他实现进行比较,结果为71秒。据我所知,分段筛法的时间复杂度与非分段筛法相同。我很欣赏在实际中,参考局部性可以提高分段版本的速度,但我的感觉是它会比Lucy_Hedgehog提出的_sub-linear_算法慢。我的问题是关于优化Haskell代码实现Lucy算法的,我现在澄清了这个问题,我可以看到它可能有点令人困惑,抱歉。不错的博客,我会阅读它,我特别喜欢你的文章。 - Adam Stelmaszczyk
分段筛法具有相同的大O时间复杂度,但由于引用局部性而会更快。感谢您对我的博客的好评。 - user448810

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接