我设计了一个计算列表平均值的函数。虽然它能正常工作,但我认为它可能不是最好的解决方案,因为它需要使用两个函数而不是一个。是否有可能只使用一个递归函数完成这项工作?
calcMeanList (x:xs) = doCalcMeanList (x:xs) 0 0
doCalcMeanList (x:xs) sum length = doCalcMeanList xs (sum+x) (length+1)
doCalcMeanList [] sum length = sum/length
我设计了一个计算列表平均值的函数。虽然它能正常工作,但我认为它可能不是最好的解决方案,因为它需要使用两个函数而不是一个。是否有可能只使用一个递归函数完成这项工作?
calcMeanList (x:xs) = doCalcMeanList (x:xs) 0 0
doCalcMeanList (x:xs) sum length = doCalcMeanList xs (sum+x) (length+1)
doCalcMeanList [] sum length = sum/length
您的解决方案很好,使用两个函数并不比使用一个更糟糕。但是,您可以将尾递归函数放在 where
子句中。
但是,如果您想要用一行代码实现:
calcMeanList = uncurry (/) . foldr (\e (s,c) -> (e+s,c+1)) (0,0)
关于您能做的最好的是这个版本:
import qualified Data.Vector.Unboxed as U
data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
mean :: U.Vector Double -> Double
mean xs = s / fromIntegral n
where
Pair n s = U.foldl' k (Pair 0 0) xs
k (Pair n s) x = Pair (n+1) (s+x)
main = print (mean $ U.enumFromN 1 (10^7))
它会融合到Core中的最佳循环(你可以编写的最好的Haskell代码):
main_$s$wfoldlM'_loop :: Int#
-> Double#
-> Double#
-> Int#
-> (# Int#, Double# #)
main_$s$wfoldlM'_loop =
\ (sc_s1nH :: Int#)
(sc1_s1nI :: Double#)
(sc2_s1nJ :: Double#)
(sc3_s1nK :: Int#) ->
case ># sc_s1nH 0 of _ {
False -> (# sc3_s1nK, sc2_s1nJ #);
True ->
main_$s$wfoldlM'_loop
(-# sc_s1nH 1)
(+## sc1_s1nI 1.0)
(+## sc2_s1nJ sc1_s1nI)
(+# sc3_s1nK 1)
}
以下是相应的汇编代码:
Main_mainzuzdszdwfoldlMzqzuloop_info:
.Lc1pN:
testq %r14,%r14
jg .Lc1pQ
movq %rsi,%rbx
movsd %xmm6,%xmm5
jmp *(%rbp)
.Lc1pQ:
leaq 1(%rsi),%rax
movsd %xmm6,%xmm0
addsd %xmm5,%xmm0
movsd %xmm5,%xmm7
addsd .Ln1pS(%rip),%xmm7
decq %r14
movsd %xmm7,%xmm5
movsd %xmm0,%xmm6
movq %rax,%rsi
jmp Main_mainzuzdszdwfoldlMzqzuloop_info
基于Data.Vector。例如:
$ ghc -Odph --make A.hs -fforce-recomp
[1 of 1] Compiling Main ( A.hs, A.o )
Linking A ...
$ time ./A
5000000.5
./A 0.04s user 0.00s system 93% cpu 0.046 total
请参考统计学包中的高效实现。
import Data.List
let avg l = let (t,n) = foldl' (\(b,c) a -> (a+b,c+1)) (0,0) l
in realToFrac(t)/realToFrac(n)
avg ([1,2,3,4]::[Int])
2.5
avg ([1,2,3,4]::[Double])
2.5
对于那些想知道glowcoder和Assaf方法在Haskell中是什么样子的人,这里有一个翻译:
avg [] = 0
avg x@(t:ts) = let xlen = toRational $ length x
tslen = toRational $ length ts
prevAvg = avg ts
in (toRational t) / xlen + prevAvg * tslen / xlen
avg2 [] = 0
avg2 x = fst $ avg_ x
where
avg_ [] = (toRational 0, toRational 0)
avg_ (t:ts) = let
(prevAvg, prevLen) = avg_ ts
curLen = prevLen + 1
curAvg = (toRational t) / curLen + prevAvg * prevLen / curLen
in (curAvg, curLen)
avg3 [] = 0
avg3 x = (toRational total) / (toRational len)
where
(total, len) = avg_ x
avg_ [] = (0, 0)
avg_ (t:ts) = let
(prevSum, prevLen) = avg_ ts
in (prevSum + t, prevLen + 1)
这个可以使用foldr更简洁地表达:
avg4 [] = 0
avg4 x = (toRational total) / (toRational len)
where
(total, len) = foldr avg_ (0,0) x
avg_ t (prevSum, prevLen) = (prevSum + t, prevLen + 1)
根据之前的帖子,这可以进一步简化。
在这里,折叠(Fold)确实是最佳选择。
针对Don在2010年的回复,我们在GHC 8.0.2上可以做得更好。首先让我们尝试他的版本。
module Main (main) where
import System.CPUTime.Rdtsc (rdtsc)
import Text.Printf (printf)
import qualified Data.Vector.Unboxed as U
data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
mean' :: U.Vector Double -> Double
mean' xs = s / fromIntegral n
where
Pair n s = U.foldl' k (Pair 0 0) xs
k (Pair n s) x = Pair (n+1) (s+x)
main :: IO ()
main = do
s <- rdtsc
let r = mean' (U.enumFromN 1 30000000)
e <- seq r rdtsc
print (e - s, r)
[nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
[1 of 1] Compiling Main ( MeanD.hs, MeanD.o )
Linking MeanD ...
(372877482,1.50000005e7)
240,104,176 bytes allocated in the heap
6,832 bytes copied during GC
44,384 bytes maximum residency (1 sample(s))
25,248 bytes maximum slop
230 MB total memory in use (0 MB lost due to fragmentation)
Tot time (elapsed) Avg pause Max pause
Gen 0 1 colls, 0 par 0.000s 0.000s 0.0000s 0.0000s
Gen 1 1 colls, 0 par 0.006s 0.006s 0.0062s 0.0062s
INIT time 0.000s ( 0.000s elapsed)
MUT time 0.087s ( 0.087s elapsed)
GC time 0.006s ( 0.006s elapsed)
EXIT time 0.006s ( 0.006s elapsed)
Total time 0.100s ( 0.099s elapsed)
%GC time 6.2% (6.2% elapsed)
Alloc rate 2,761,447,559 bytes per MUT second
Productivity 93.8% of total user, 93.8% of total elapsed
module Main (main) where
import System.CPUTime.Rdtsc (rdtsc)
import Text.Printf (printf)
import Data.List (foldl')
data Pair = Pair {-# UNPACK #-}!Int {-# UNPACK #-}!Double
mean' :: [Double] -> Double
mean' xs = v / fromIntegral l
where
Pair l v = foldl' f (Pair 0 0) xs
f (Pair l' v') x = Pair (l' + 1) (v' + x)
main :: IO ()
main = do
s <- rdtsc
let r = mean' $ fromIntegral <$> [1 :: Int .. 30000000]
-- This is slow!
-- r = mean' [1 .. 30000000]
e <- seq r rdtsc
print (e - s, r)
[nix-shell:/tmp]$ ghc -fforce-recomp -O2 MeanD.hs -o MeanD && ./MeanD +RTS -s
[1 of 1] Compiling Main ( MeanD.hs, MeanD.o )
Linking MeanD ...
(128434754,1.50000005e7)
104,064 bytes allocated in the heap
3,480 bytes copied during GC
44,384 bytes maximum residency (1 sample(s))
17,056 bytes maximum slop
1 MB total memory in use (0 MB lost due to fragmentation)
Tot time (elapsed) Avg pause Max pause
Gen 0 0 colls, 0 par 0.000s 0.000s 0.0000s 0.0000s
Gen 1 1 colls, 0 par 0.000s 0.000s 0.0000s 0.0000s
INIT time 0.000s ( 0.000s elapsed)
MUT time 0.032s ( 0.032s elapsed)
GC time 0.000s ( 0.000s elapsed)
EXIT time 0.000s ( 0.000s elapsed)
Total time 0.033s ( 0.032s elapsed)
%GC time 0.1% (0.1% elapsed)
Alloc rate 3,244,739 bytes per MUT second
Productivity 99.8% of total user, 99.8% of total elapsed
fromIntegral
:如果没有这个,GHC无法消除[Double]
,解决方案会慢得多。这有点令人难过:我不明白为什么GHC在没有这个的情况下无法内联/决定不需要。如果您确实拥有分数集合,则此 hack 对您无效,可能仍需要使用 vector。[Int]
上工作并使用 -fllvm
,那么在这种情况下我们能够获得几乎恒定时间的答案。 - Mateusz Kowalczyk
sum
和length
都是Prelude中的函数,所以你可能想使用其他变量名来避免混淆。 - jkramer