为什么这个简单的Haskell算法如此缓慢?

19

剧透警告:这涉及到来自欧拉计划的问题14

下面的代码需要大约15秒才能运行。我有一个非递归的Java解决方案,只需1秒即可运行。我认为我应该能够将这段代码优化到更接近那个速度。

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])
我使用了+RHS -p进行了性能分析并发现分配的内存很大,随着输入量的增加而增长。对于n = 100,000,分配了1GB的内存(!),对于n = 1,000,000,分配了13GB(!!)的内存。

然而,-sstderr显示尽管分配了大量字节,总内存使用量为1MB,生产率达到95%以上,因此可能13GB是误导性的。

我可以想到一些可能性:

  1. 某些部分不够严格。我已经发现了foldl1',但也许我需要做更多的工作?是否有可能将collatz标记为严格(这种做法是否合理?

  2. collatz没有实现尾递归优化。我认为它应该是这样的,但不知道如何确认。

  3. 编译器没有执行我认为应该执行的一些优化——例如,collatz的只有两个结果需要在任何时候都在内存中(最大值和当前值)

有什么建议吗?

这与Why is this Haskell expression so slow?几乎是重复的,尽管我会指出,快速的Java解决方案不需要执行任何记忆化。有没有办法在不必诉诸记忆化的情况下加快这个程序的运行速度呢?

以下是我的性能分析输出供参考:

  Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

     scratch +RTS -p -hc -RTS

  total time  =        5.12 secs   (256 ticks @ 20 ms)
  total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc

collatz                        Main                  99.6   99.4


                                                                                               individual    inherited
COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
 CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
  collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
  main                   Main                                                 214           1   0.4    0.6   100.0  100.0
   collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
 CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
 CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
 CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

而且,-sstderr:

./scratch +RTS -sstderr 
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))          
          12,824 bytes maximum slop               
               1 MB total memory in use (0 MB lost due to fragmentation)  

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed

以下是Java解决方案(不是我写的,来自于Project Euler论坛并移除了记忆化):

public class Collatz {
  public int getChainLength( int n )
  {
    long num = n;
    int count = 1;
    while( num > 1 )
    {
      num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
      count++;
    }
    return count;
  }

  public static void main(String[] args) {
    Collatz obj = new Collatz();
    long tic = System.currentTimeMillis();
    int max = 0, len = 0, index = 0;
    for( int i = 3; i < 1000000; i++ )
    {
      len = obj.getChainLength(i);
      if( len > max )
      {
        max = len;
        index = i;
      }
    }
    long toc = System.currentTimeMillis();
    System.out.println(toc-tic);
    System.out.println( "Index: " + index + ", length = " + max );
  }
}

令人惊讶的是,GHC没有像任何自重的C编译器一样将(quot n 2)优化为(rshift n 1)。这是有原因的吗? - user1120317
@solrize:这也让我感到惊讶。 - ehird
3个回答

21

起初,我认为你应该在collatz中的a前加上一个感叹号:

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)
(为了让这个方法生效,您需要在源文件顶部添加 {-# LANGUAGE BangPatterns #-}。)
我的想法如下:问题在于您正在第一个参数中累积一个巨大的“thunk”,它从1开始,然后变成1 + 1,然后变成(1 + 1) + 1,...所有这些都没有被强制执行。该bang模式会强制在调用时强制执行collatz的第一个参数,因此它从1开始,然后变成2,依此类推,而不会建立一个大的未求值的thunk: 它只保持为一个整数。
请注意,bang模式只是使用seq的速记方式;在本例中,我们可以将collatz重写为以下形式:
collatz a _ | seq a False = undefined
collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)
这里的诀窍是强制将a放在保护条件中,这样它总是评估为False(因此主体不相关)。然后继续使用下一个案例进行评估,a已被评估。然而,一种更清晰的方法是使用bang模式。
不幸的是,使用-O2编译后,这并没有比原始版本运行得更快!我们还能尝试什么?唯一的办法就是假定这两个数字永远不会溢出机器大小的整数,并给collatz添加这个类型注释:
collatz :: Int -> Int -> Int

即使它们不是性能问题的根源,我们仍应避免构建thunk。因此,我们将保留Bang模式,这将在我的(慢)计算机上将时间缩短到8.5秒。

下一步是尝试将其接近Java解决方案。首先要意识到的是,在Haskell中,div对于负整数的处理方式更符合数学规则,但比“常规”的C除法慢,在Haskell中称为quot。将div替换为quot将运行时间降至5.2秒,将x `quot` 2替换为x `shiftR` 1(导入Data.Bits),以与Java解决方案相匹配,则将其降至4.9秒。

目前我已经做到了最低限度,但我认为这是一个非常好的结果;由于您的计算机比我的快,希望它能更接近Java解决方案。

以下是最终代码(我进行了一些清理):

{-# LANGUAGE BangPatterns #-}

import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]

通过查看这个程序的 GHC Core(使用 ghc-core),我认为这已经是最好的了;collatz 循环使用了非装箱整数,程序的其余部分看起来也不错。我能想到的唯一改进就是消除 map collatz [1..1000000] 迭代中的装箱。

顺便说一下,不必担心 "total alloc" 数字;它是程序生命周期内分配的总内存,即使 GC 回收了该内存,它也永远不会减少。多个 TB 的数字很常见。


谢谢,这真的很有帮助。我不知道 -O2,这使得巨大的差异(运行时间降至5秒)。已将Java解决方案添加到问题中。 - Xavier Shay
哦,我以为你已经在使用“-O2”了,因为带有惊叹号模式的修订程序在我的机器上运行了16秒 :) 我会看一下你的Java解决方案。 - ehird
顺便说一下,Java程序实际上是从x=3开始的,但性能影响可以忽略不计,而且感觉像作弊,所以我没有让Haskell程序也这样做 :) - ehird
shiftR 把运行时间降到了1.5秒。我对此感到满意!再次感谢。 - Xavier Shay
没问题 :) 我刚刚添加了一段关于 GHC 生成程序的内部核心的段落;看起来它已经趋于完美了。 - ehird
1
原来,在64位机器上只能使用Int(或者实际上是Word)。在32位机器上,该序列将会溢出。只是说一下,为了未来可能在这里阅读的读者们的利益。 :) - Will Ness

2

通过使用堆栈而不是列表和bang模式,您仍然可以获得相同的性能。

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

这种方法的一个问题是,对于像1_000_000这样的大输入,您将不得不增加堆栈的大小。

更新:

这里是一种尾递归版本,不会遭受堆栈溢出问题。

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

注意使用Word而不是Int,这会影响性能。如果您仍想使用感叹号模式,那么性能将近乎翻倍。


0

在这个问题中,我发现有一件事情让我感到惊讶的不同。我坚持使用直接的递归关系,而不是将计数与之折叠在一起。重写:

collatz n = if even n then n `div` 2 else 3 * n + 1

作为

collatz n = case n `divMod` 2 of
            (n', 0) -> n'
            _       -> 3 * n + 1

在一台2.8 GHz Athlon II X4 430 CPU的系统上,我的程序运行时间减少了1.2秒。使用divMod后,我的初始更快版本为2.3秒:

{-# LANGUAGE BangPatterns #-}

import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
    where collatzChainLen' n !l
            | n == 1    = l
            | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz:: Int -> Int
collatz n = case n `divMod` 2 of
                 (n', 0) -> n'
                 _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))

一个可能更符合习惯的 Haskell 版本运行大约需要 9.7 秒(使用 divMod 后为 8.5 秒);除了这一点,它与原版完全相同。

collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

使用Data.List.Stream应该允许流融合,使得这个版本更像明确累加的版本,但我找不到一个具有Data.List.Stream的Ubuntu libghc*软件包,所以我还不能验证它。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接