在Haskell中高效计算列表的平均值

11

我设计了一个计算列表平均值的函数。虽然它能正常工作,但我认为它可能不是最好的解决方案,因为它需要使用两个函数而不是一个。是否有可能只使用一个递归函数完成这项工作?

calcMeanList (x:xs) = doCalcMeanList (x:xs) 0 0

doCalcMeanList (x:xs) sum length =  doCalcMeanList xs (sum+x) (length+1)
doCalcMeanList [] sum length = sum/length

1
记住,任何简单除法的解决方案都会对空列表产生NaN。这并不一定是问题,只是我认为值得注意的事情。 - Chuck
可能是Laziness and tail recursion in Haskell, why is this crashing?的重复问题。 - Don Stewart
惰性求值和尾递归在Haskell中的应用:为什么这会导致程序崩溃? - Don Stewart
很抱歉重复提问,下次我会更仔细搜索。 - snowmantw
@snowmantw:你不可能知道,那个问题的标题中没有任何暗示它是一个有关计算平均值的问题。@Don Stewart:我认为这不是重复问题。代码非常相似,但是关于代码的问题是完全不同的。 - Owen S.
顺便提一下,sumlength都是Prelude中的函数,所以你可能想使用其他变量名来避免混淆。 - jkramer
6个回答

11

您的解决方案很好,使用两个函数并不比使用一个更糟糕。但是,您可以将尾递归函数放在 where 子句中。

但是,如果您想要用一行代码实现:

calcMeanList = uncurry (/) . foldr (\e (s,c) -> (e+s,c+1)) (0,0)

1
为什么使用 foldr 而不是 foldl?我认为后者更适合。 - Axman6
可以在这里使用foldl、foldl'或foldr,因为你必须遍历整个列表(我选择了其中一个)……我认为如果性能很重要,可以在这里使用foldl'。 - Kru

10

关于您能做的最好的是这个版本:

import qualified Data.Vector.Unboxed as U

data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double

mean :: U.Vector Double -> Double
mean xs = s / fromIntegral n
  where
    Pair n s       = U.foldl' k (Pair 0 0) xs
    k (Pair n s) x = Pair (n+1) (s+x)

main = print (mean $ U.enumFromN 1 (10^7))

它会融合到Core中的最佳循环(你可以编写的最好的Haskell代码):

main_$s$wfoldlM'_loop :: Int#
                              -> Double#
                              -> Double#
                              -> Int#
                              -> (# Int#, Double# #)    
main_$s$wfoldlM'_loop =
  \ (sc_s1nH :: Int#)
    (sc1_s1nI :: Double#)
    (sc2_s1nJ :: Double#)
    (sc3_s1nK :: Int#) ->
    case ># sc_s1nH 0 of _ {
      False -> (# sc3_s1nK, sc2_s1nJ #);
      True ->
        main_$s$wfoldlM'_loop
          (-# sc_s1nH 1)
          (+## sc1_s1nI 1.0)
          (+## sc2_s1nJ sc1_s1nI)
          (+# sc3_s1nK 1)
    }

以下是相应的汇编代码:

Main_mainzuzdszdwfoldlMzqzuloop_info:
.Lc1pN:
        testq %r14,%r14
        jg .Lc1pQ
        movq %rsi,%rbx
        movsd %xmm6,%xmm5
        jmp *(%rbp)
.Lc1pQ:
        leaq 1(%rsi),%rax
        movsd %xmm6,%xmm0
        addsd %xmm5,%xmm0
        movsd %xmm5,%xmm7
        addsd .Ln1pS(%rip),%xmm7
        decq %r14
        movsd %xmm7,%xmm5
        movsd %xmm0,%xmm6
        movq %rax,%rsi
        jmp Main_mainzuzdszdwfoldlMzqzuloop_info

基于Data.Vector。例如:

$ ghc -Odph --make A.hs -fforce-recomp
[1 of 1] Compiling Main             ( A.hs, A.o )
Linking A ...
$ time ./A
5000000.5
./A  0.04s user 0.00s system 93% cpu 0.046 total

请参考统计学包中的高效实现。


5
当我看到你的问题时,我立刻想到了“你想在那里折叠一个fold!”
而且,在StackOverflow上以前已经有人问过类似的问题a similar question,并且this answer有一个非常高效的解决方案,你可以在像GHCi这样的交互式环境中进行测试。
import Data.List

let avg l = let (t,n) = foldl' (\(b,c) a -> (a+b,c+1)) (0,0) l 
            in realToFrac(t)/realToFrac(n)

avg ([1,2,3,4]::[Int])
2.5
avg ([1,2,3,4]::[Double])
2.5

3

对于那些想知道glowcoder和Assaf方法在Haskell中是什么样子的人,这里有一个翻译:

avg [] = 0
avg x@(t:ts) = let xlen = toRational $ length x
                   tslen = toRational $ length ts
                   prevAvg = avg ts
               in (toRational t) / xlen + prevAvg * tslen / xlen

这种方式确保每个步骤都正确计算了“迄今为止的平均值”,但代价是大量冗余的长度乘除和每个步骤中非常低效的长度计算。没有经验的Haskeller不会这样写。
一种略微更好的方法是:
avg2 [] = 0
avg2 x = fst $ avg_ x
    where 
      avg_ [] = (toRational 0, toRational 0)
      avg_ (t:ts) = let
           (prevAvg, prevLen) = avg_ ts
           curLen = prevLen + 1
           curAvg = (toRational t) / curLen + prevAvg * prevLen / curLen
        in (curAvg, curLen)

这样可以避免重复的长度计算。但是这需要一个辅助函数,这恰恰是原始帖子作者试图避免的。而且它仍然需要大量抵消长度项。
为了避免长度的相互抵消,我们可以在最后进行总和和长度的累加,再进行除法运算:
avg3 [] = 0
avg3 x = (toRational total) / (toRational len)
    where 
      (total, len) = avg_ x
      avg_ [] = (0, 0)
      avg_ (t:ts) = let 
          (prevSum, prevLen) = avg_ ts
       in (prevSum + t, prevLen + 1)

这个可以使用foldr更简洁地表达:

avg4 [] = 0
avg4 x = (toRational total) / (toRational len)
    where
      (total, len) = foldr avg_ (0,0) x
      avg_ t (prevSum, prevLen) = (prevSum + t, prevLen + 1)

根据之前的帖子,这可以进一步简化。

在这里,折叠(Fold)确实是最佳选择。


3
虽然我不确定在一个函数中写是否最好,但可以按以下方式完成:
如果您预先知道长度(这里称为“n”),那很容易-您可以计算每个值“增加”平均值的量;即值/长度。由于avg(x1,x2,x3)= sum(x1,x2,x3)/length =(x1 + x2 + x3)/ 3 = x1 / 3 + x2 / 3 + x2 / 3
如果您不知道长度,则有点棘手:
假设我们使用列表{x1,x2,x3},而不知道其n = 3。
第一次迭代将只是x1(因为我们假设它只有n = 1) 第二次迭代会添加x2 / 2,并将现有平均值除以2,因此现在我们有x1 / 2 + x2 / 2
第三次迭代后,我们有n = 3,并且我们想要x1 / 3 + x2 / 3 + x3 / 3,但是我们有x1 / 2 + x2 / 2
因此,我们需要乘以(n-1)并除以n,以获得x1 / 3 + x2 / 3,然后我们只需将当前值(x3)除以n,以得到x1 / 3 + x2 / 3 + x3 / 3
通常:
给定n-1项的平均值(算术平均值-avg),如果要将一个项(newval)添加到平均值中,则您的方程将为:
avg *(n-1)/ n + newval / n。可以使用归纳法数学证明方程。
希望这会有所帮助。
*请注意,此解决方案的效率低于仅对变量求和并按总长度进行除法的效率,就像您在示例中所做的那样。

1

针对Don在2010年的回复,我们在GHC 8.0.2上可以做得更好。首先让我们尝试他的版本。

module Main (main) where

import System.CPUTime.Rdtsc (rdtsc)
import Text.Printf (printf)
import qualified Data.Vector.Unboxed as U

data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double

mean' :: U.Vector Double -> Double
mean' xs = s / fromIntegral n
  where
    Pair n s       = U.foldl' k (Pair 0 0) xs
    k (Pair n s) x = Pair (n+1) (s+x)

main :: IO ()
main = do
  s <- rdtsc
  let r = mean' (U.enumFromN 1 30000000)
  e <- seq r rdtsc
  print (e - s, r)

这个给了我们。
[nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
[1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
Linking MeanD ...
(372877482,1.50000005e7)
     240,104,176 bytes allocated in the heap
           6,832 bytes copied during GC
          44,384 bytes maximum residency (1 sample(s))
          25,248 bytes maximum slop
             230 MB total memory in use (0 MB lost due to fragmentation)

                                     Tot time (elapsed)  Avg pause  Max pause
  Gen  0         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
  Gen  1         1 colls,     0 par    0.006s   0.006s     0.0062s    0.0062s

  INIT    time    0.000s  (  0.000s elapsed)
  MUT     time    0.087s  (  0.087s elapsed)
  GC      time    0.006s  (  0.006s elapsed)
  EXIT    time    0.006s  (  0.006s elapsed)
  Total   time    0.100s  (  0.099s elapsed)

  %GC     time       6.2%  (6.2% elapsed)

  Alloc rate    2,761,447,559 bytes per MUT second

  Productivity  93.8% of total user, 93.8% of total elapsed

然而,代码很简单:理想情况下不需要 vector,只需将列表生成内联即可实现最优代码。幸运的是,GHC 可以为我们完成这个任务[0]。
module Main (main) where

import System.CPUTime.Rdtsc (rdtsc)
import Text.Printf (printf)
import Data.List (foldl')

data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double

mean' :: [Double] -> Double
mean' xs = v / fromIntegral l
  where
    Pair l v = foldl' f (Pair 0 0) xs
    f (Pair l' v') x = Pair (l' + 1) (v' + x)

main :: IO ()
main = do
  s <- rdtsc
  let r = mean' $ fromIntegral <$> [1 :: Int .. 30000000]
      -- This is slow!
      -- r = mean' [1 .. 30000000]
  e <- seq r rdtsc
  print (e - s, r)

这给我们带来了:
[nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
[1 of 1] Compiling Main             ( MeanD.hs, MeanD.o )
Linking MeanD ...
(128434754,1.50000005e7)
         104,064 bytes allocated in the heap
           3,480 bytes copied during GC
          44,384 bytes maximum residency (1 sample(s))
          17,056 bytes maximum slop
               1 MB total memory in use (0 MB lost due to fragmentation)

                                     Tot time (elapsed)  Avg pause  Max pause
  Gen  0         0 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s
  Gen  1         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s

  INIT    time    0.000s  (  0.000s elapsed)
  MUT     time    0.032s  (  0.032s elapsed)
  GC      time    0.000s  (  0.000s elapsed)
  EXIT    time    0.000s  (  0.000s elapsed)
  Total   time    0.033s  (  0.032s elapsed)

  %GC     time       0.1%  (0.1% elapsed)

  Alloc rate    3,244,739 bytes per MUT second

  Productivity  99.8% of total user, 99.8% of total elapsed

[0]: 注意我必须映射fromIntegral:如果没有这个,GHC无法消除[Double],解决方案会慢得多。这有点令人难过:我不明白为什么GHC在没有这个的情况下无法内联/决定不需要。如果您确实拥有分数集合,则此 hack 对您无效,可能仍需要使用 vector。

另外一个有趣的事情是:如果我们在 [Int] 上工作并使用 -fllvm,那么在这种情况下我们能够获得几乎恒定时间的答案。 - Mateusz Kowalczyk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接