优化Haskell代码

16
我正在尝试学习Haskell,在Reddit上看了一篇关于Markov文本链的文章后,我决定先在Python中实现Markov文本生成,然后再用Haskell实现。但是我注意到我的Python实现比Haskell版本快得多,即使Haskell已编译为本机代码。我想知道应该怎么做可以让Haskell代码运行得更快,目前我认为它运行得如此之慢是因为使用了Data.Map而不是哈希映射表,但我不确定。
我会发布Python代码和Haskell代码。使用相同的数据,Python需要约3秒钟,而Haskell则需要接近16秒钟。
需要指出的是,我将接受任何建设性的批评 :)。
import random
import re
import cPickle
class Markov:
    def __init__(self, filenames):
        self.filenames = filenames
        self.cache = self.train(self.readfiles())
        picklefd = open("dump", "w")
        cPickle.dump(self.cache, picklefd)
        picklefd.close()

    def train(self, text):
        splitted = re.findall(r"(\w+|[.!?',])", text)
        print "Total of %d splitted words" % (len(splitted))
        cache = {}
        for i in xrange(len(splitted)-2):
            pair = (splitted[i], splitted[i+1])
            followup = splitted[i+2]
            if pair in cache:
                if followup not in cache[pair]:
                    cache[pair][followup] = 1
                else:
                    cache[pair][followup] += 1
            else:
                cache[pair] = {followup: 1}
        return cache

    def readfiles(self):
        data = ""
        for filename in self.filenames:
            fd = open(filename)
            data += fd.read()
            fd.close()
        return data

    def concat(self, words):
        sentence = ""
        for word in words:
            if word in "'\",?!:;.":
                sentence = sentence[0:-1] + word + " "
            else:
                sentence += word + " "
        return sentence

    def pickword(self, words):
        temp = [(k, words[k]) for k in words]
        results = []
        for (word, n) in temp:
            results.append(word)
            if n > 1:
                for i in xrange(n-1):
                    results.append(word)
        return random.choice(results)

    def gentext(self, words):
        allwords = [k for k in self.cache]
        (first, second) = random.choice(filter(lambda (a,b): a.istitle(), [k for k in self.cache]))
        sentence = [first, second]
        while len(sentence) < words or sentence[-1] is not ".":
            current = (sentence[-2], sentence[-1])
            if current in self.cache:
                followup = self.pickword(self.cache[current])
                sentence.append(followup)
            else:
                print "Wasn't able to. Breaking"
                break
        print self.concat(sentence)

Markov(["76.txt"])

--

module Markov
( train
, fox
) where

import Debug.Trace
import qualified Data.Map as M
import qualified System.Random as R
import qualified Data.ByteString.Char8 as B


type Database = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train :: [B.ByteString] -> Database
train (x:y:[]) = M.empty
train (x:y:z:xs) = 
     let l = train (y:z:xs)
     in M.insertWith' (\new old -> M.insertWith' (+) z 1 old) (x, y) (M.singleton z 1) `seq` l

main = do
  contents <- B.readFile "76.txt"
  print $ train $ B.words contents

fox="The quick brown fox jumps over the brown fox who is slow jumps over the brown fox who is dead."

1
有趣,也在寻找答案。16秒与3秒的差距确实很大。 - wvd
Python代码的缩进似乎出现了混乱,顺便说一下... - C. A. McCann
1
我认为你的Haskell代码没有达到你想要的效果。如果你检查输出,你会发现在M.Map String Int映射中没有大于2的值。你是指n + o或者o + 1而不是n + 1吗? - Travis Brown
@Travis,你说得完全正确,但应该在编辑版本中修复。 - Masse
2
您在以“in M.insertWith'”开头的那行代码中使用seq很可疑。您正在构建一个大表达式并对其进行求值,然后丢弃结果并返回其他内容。您是否意味着切换参数,即l seq M.insertWith ...? - luqui
6个回答

11

a) 你是如何编译它的?(使用ghc -O2吗?)

b) 使用哪个版本的GHC?

c) Data.Map 相当高效,但你可能会被迫进行惰性更新——使用 insertWith' 而不是 insertWithKey。

d) 不要将 bytestrings 转换为 String。保持它们作为 bytestrings,并将它们存储在 Map 中。


版本号是6.12.1。在您的帮助下,我能够从运行时间中挤出1秒钟,但它仍然远远落后于Python版本。 - Masse

9

Data.Map是基于假设类Ord比较需要恒定时间的情况下设计的。对于字符串键可能不是这种情况,当字符串相等时就绝不是这种情况。根据你的语料库大小和有多少单词具有共同前缀,你可能会遇到这个问题。

我建议尝试一种设计用于操作序列键的数据结构,例如bytestring-trie包,由Don Stewart友情推荐。


@don:感谢更新。我相信你至少能够以名称的形式了解 Hackage 内容的 60% :-) - Norman Ramsey

7

我尝试避免做任何花哨或微妙的事情。这只是两种分组方法中的两种方法;第一种强调模式匹配,第二种则不强调。

import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.ByteString.Char8 as B

type Database2 = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train2 :: [B.ByteString] -> Database2
train2 words = go words M.empty
    where go (x:y:[]) m = m
          go (x:y:z:xs) m = let addWord Nothing   = Just $ M.singleton z 1
                                addWord (Just m') = Just $ M.alter inc z m'
                                inc Nothing    = Just 1
                                inc (Just cnt) = Just $ cnt + 1
                            in go (y:z:xs) $ M.alter addWord (x,y) m

train3 :: [B.ByteString] -> Database2
train3 words = foldl' update M.empty (zip3 words (drop 1 words) (drop 2 words))
    where update m (x,y,z) = M.alter (addWord z) (x,y) m
          addWord word = Just . maybe (M.singleton word 1) (M.alter inc word)
          inc = Just . maybe 1 (+1)

main = do contents <- B.readFile "76.txt"
          let db = train3 $ B.words contents
          print $ "Built a DB of " ++ show (M.size db) ++ " words"

我认为它们都比原始版本快,但诚实地说,我只是在第一个合理的语料库中测试过它们。

编辑 根据Travis Brown非常有道理的观点,

train4 :: [B.ByteString] -> Database2
train4 words = foldl' update M.empty (zip3 words (drop 1 words) (drop 2 words))
    where update m (x,y,z) = M.insertWith (inc z) (x,y) (M.singleton z 1) m
          inc k _ = M.insertWith (+) k 1

作为一种风格,我认为在这里使用比“修改”更具体的东西会更好。我们知道在这种情况下永远不需要删除,而像这样添加“Just”会影响可读性。 - Travis Brown
抱歉回复晚了。您能否解释一下为什么这是更快的解决方案?基本上两者都是相同的,除了压缩和删除。 - Masse

3
这是一个基于foldl'的版本,看起来比你的train快大约两倍:
train' :: [B.ByteString] -> Database
train' xs = foldl' (flip f) M.empty $ zip3 xs (tail xs) (tail $ tail xs)
  where
    f (a, b, c) = M.insertWith (M.unionWith (+)) (a, b) (M.singleton c 1)

我在Project Gutenberg的《哈克贝利·费恩历险记》上尝试了一下,得到了与你的函数相同的输出结果。虽然我的时间比较不够科学严谨,但这种方法可能值得一试。


2

1) 我不太清楚你的代码。 a)你定义了“fox”,但没有使用它。你是想让我们尝试使用“fox”而不是读取文件来帮助你吗? b)你将其声明为“module Markov”,然后在模块中有一个“main”。 c)System.Random不是必需的。如果您在发布之前稍微清理一下代码,可以帮助我们帮助您。

2)像Don所说,使用ByteStrings和一些严格的操作。

3)使用-O2编译并使用-fforce-recomp确保您实际重新编译了代码。

4)尝试这个轻微的转换,它非常快(0.005秒)。显然,输入非常小,因此您需要提供您的文件或自己进行测试。

{-# LANGUAGE OverloadedStrings, BangPatterns #-}
module Main where

import qualified Data.Map as M
import qualified Data.ByteString.Lazy.Char8 as B


type Database = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train :: [B.ByteString] -> Database
train xs = go xs M.empty
  where
  go :: [B.ByteString] -> Database -> Database
  go (x:y:[]) !m = m
  go (x:y:z:xs) !m =
     let m' =  M.insertWithKey' (\key new old -> M.insertWithKey' (\_ n o -> n + 1) z 1 old) (x, y) (M.singleton z 1) m
     in go (y:z:xs) m'

main = print $ train $ B.words fox

fox="The quick brown fox jumps over the brown fox who is slow jumps over the brown fox who is dead."

好的,是的,就像标签上说的那样,我是一个初学者:P。我没有意识到将模块命名为除Main以外的其他名称的后果。而且狐狸是用来测试算法的。检查小输入比整本书的输入更容易。 - Masse

1

正如Don所建议的那样,尝试使用您的函数的更严格版本:insertWithKey'(和M.insertWith',因为您无论如何第二次都会忽略key参数)。

看起来您的代码可能会在到达[String]末尾之前积累大量的thunk。

请查看:http://book.realworldhaskell.org/read/profiling-and-optimization.html

...特别是尝试绘制堆(在本章中间部分左右)。很想看看您发现了什么。


我已经按照Don Stewart的建议进行了更改。以前的代码占用41-44兆字节的内存,现在只占用29兆字节。将内存制图显示,TSO占用大部分内存,然后是GHC.types,接着是代码中使用的其他数据类型。所有部分的内存在一秒钟内迅速增加。那一秒钟后,TSO和GHC.types仍在持续增加,其他所有部分开始缓慢减少。(如果我观察图表正确的话) - Masse

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接