使用multiprocessing.Pool()比仅使用普通函数慢

58

(这个问题是关于如何让multiprocessing.Pool()运行代码更快的。我最终解决了这个问题,最终解决方案可以在本帖底部找到。)

原问题:

我正在尝试使用Python将一个单词与列表中的许多其他单词进行比较,并检索出最相似的单词列表。为此,我使用difflib.get_close_matches函数。我使用的是相对较新且功能强大的Windows 7笔记本电脑,配备Python 2.6.5。

我想要加速比较过程,因为我的单词比较列表非常长,而我必须多次重复比较过程。当我听说有关多进程模块时,如果可以将比较分解为工作任务并同时运行(从而利用机器功率以换取更快的速度),我的比较任务将更快地完成,这似乎很合理。

但是,即使尝试了许多不同的方式,并使用已经显示在文档和论坛帖子中的方法,Pool方法似乎仍然非常缓慢,比一次在整个列表上运行原始get_close_match函数要慢得多。我想要理解为什么Pool()如此缓慢,以及我是否使用它正确。我仅将这个字符串比较方案用作示例,因为那是我最近无法理解或使多处理工作的示例,而不是反对我。下面只是difflib场景中的示例代码,显示了普通方法和Pooled方法之间的时间差异:

from multiprocessing import Pool
import random, time, difflib

# constants
wordlist = ["".join([random.choice([letter for letter in "abcdefghijklmnopqersty"]) for lengthofword in xrange(5)]) for nrofwords in xrange(1000000)]
mainword = "hello"

# comparison function
def findclosematch(subwordlist):
    matches = difflib.get_close_matches(mainword,subwordlist,len(subwordlist),0.7)
    if matches <> []:
        return matches

# pool
print "pool method"
if __name__ == '__main__':
    pool = Pool(processes=3)
    t=time.time()
    result = pool.map_async(findclosematch, wordlist, chunksize=100)
    #do something with result
    for r in result.get():
        pass
    print time.time()-t

# normal
print "normal method"
t=time.time()
# run function
result = findclosematch(wordlist)
# do something with results
for r in result:
    pass
print time.time()-t

要找的单词是“hello”,要查找的单词列表是一个由5个随机字符组成的100万项长列表(仅用于说明目的)。我使用3个处理器核心和映射函数,块大小为100(我认为是每个工作程序要处理的列表项?)(我还尝试过1000和10000的块大小,但没有真正的区别)。请注意,在两种方法中,我都在调用函数之前启动计时器,并在循环遍历结果后立即结束计时器。如下所示,计时结果明显有利于原始的非池方法:
>>> 
pool method
37.1690001488 seconds
normal method
10.5329999924 seconds
>>> 

池方法比原始方法慢近4倍。我是否遗漏了什么,或者对池化/多进程工作方式的理解可能存在误解?我怀疑问题的一部分可能是map函数返回None,因此即使我只想返回实际匹配项并已在函数中编写为此类结果,它也会向结果列表添加数千个不必要的项。据我所知,这就是map的工作方式。我听说过一些其他函数(如filter)只收集非False结果,但我认为multiprocessing/Pool不支持filter方法。除了multiprocessing模块中的map/imap之外,还有其他函数可以帮助我仅返回函数返回的内容吗?根据我的理解,apply函数更适用于提供多个参数。
我知道还有imap函数,但没有任何时间改进。原因是我遇到了理解itertools模块有多好的同样问题,它被认为是“闪电般快速”的,我注意到调用函数的确很快,但根据我的经验和我所读的,那是因为调用函数实际上并未进行任何计算,因此当需要迭代结果以收集和分析它们时(如果没有这些,调用该函数就没有意义),它所需的时间与直接使用函数的普通版本相同或甚至更长。但我想这是另一个帖子的问题了。
无论如何,很高兴看到有人能在正确的方向上给我指引,并真的非常感谢任何帮助。我更感兴趣的是理解多进程工作方式,而不是让这个示例正常工作,尽管一些示例解决方案代码建议对我的理解会有所帮助。
答案:
似乎减速与附加进程的缓慢启动时间有关。我无法使.Pool()函数足够快。我最终的解决方案是手动拆分工作负载列表,使用多个.Process()而不是.Pool(),并在队列中返回解决方案。但我想知道最关键的变化可能是将工作负载拆分为要查找的主要单词,而不是要进行比较的单词,可能是因为difflib搜索函数已经非常快了。这是新代码同时运行5个进程,结果比运行简单代码(6秒 vs 55秒)快约x10。在difflib已经非常快的情况下,非常适用于快速模糊查找。
from multiprocessing import Process, Queue
import difflib, random, time

def f2(wordlist, mainwordlist, q):
    for mainword in mainwordlist:
        matches = difflib.get_close_matches(mainword,wordlist,len(wordlist),0.7)
        q.put(matches)

if __name__ == '__main__':

    # constants (for 50 input words, find closest match in list of 100 000 comparison words)
    q = Queue()
    wordlist = ["".join([random.choice([letter for letter in "abcdefghijklmnopqersty"]) for lengthofword in xrange(5)]) for nrofwords in xrange(100000)]
    mainword = "hello"
    mainwordlist = [mainword for each in xrange(50)]

    # normal approach
    t = time.time()
    for mainword in mainwordlist:
        matches = difflib.get_close_matches(mainword,wordlist,len(wordlist),0.7)
        q.put(matches)
    print time.time()-t

    # split work into 5 or 10 processes
    processes = 5
    def splitlist(inlist, chunksize):
        return [inlist[x:x+chunksize] for x in xrange(0, len(inlist), chunksize)]
    print len(mainwordlist)/processes
    mainwordlistsplitted = splitlist(mainwordlist, len(mainwordlist)/processes)
    print "list ready"

    t = time.time()
    for submainwordlist in mainwordlistsplitted:
        print "sub"
        p = Process(target=f2, args=(wordlist,submainwordlist,q,))
        p.Daemon = True
        p.start()
    for submainwordlist in mainwordlistsplitted:
        p.join()
    print time.time()-t
    while True:
        print q.get()

你尝试过增加块大小吗?比如说chunksize=100000之类的? - Hannes Ovrén
1
然后更改调用,使findclosematch()做更多的工作。否则,pickling / unpickling参数将占主导地位。 - jfs
如果 matches <> [],那么这种写法很糟糕,请使用 if matches: 代替。 - jfs
1
不要使用<>,它已经被废弃了很长时间,在Python3中会引发SyntaxError,因此使用它会使代码变得不太向前兼容。请注意,生成进程和进程间通信的成本非常高。如果您想通过多个进程减少时间,必须确保计算时间足够长,以便开销不重要。在您的情况下,我认为这并不是真的。 - Bakuriu
1
同时,if matches: 检查是完全无用的,可能会产生错误。我刚试着运行脚本,修改了一些参数,由于这个虚假检查而得到了一个 TypeError: NoneType object is not iterable 错误。99.9% 的情况下,函数应该始终返回相同的结果。不要使用 None 特殊处理空结果,因为这只会使代码中其他部分处理函数结果变得更加复杂。 - Bakuriu
显示剩余4条评论
4个回答

48
这些问题通常可以归结为以下几点:

你尝试并行化的函数没有足够的CPU资源(即CPU时间)来支持并行化!

当你使用 multiprocessing.Pool(8) 进行并行化时,你理论上可以获得 8倍 的加速效果,但实际情况可能不是这样。

然而,请记住这不是免费的 - 你要付出以下开销才能获得并行化的好处:

  1. 为传递给 Pool.map(f, iter)iter 中的每个 chunk(大小为 chunksize)创建一个 task
  2. 对于每个 task
    1. tasktask 的返回值序列化(类似于 pickle.dumps())
    2. 反序列化 tasktask 的返回值(类似于 pickle.loads())
    3. 在工作进程和父进程从/向这些队列获取/放置元素时,浪费大量时间等待共享内存队列上的锁(例如Locks)。
  3. 每个工作进程调用 os.fork() 的一次性开销。这是很昂贵的。

实际上,当你使用 Pool() 时,你需要:

  1. 高CPU资源要求
  2. 传递给每个函数调用的数据占用内存小
  • 合理长度的 iter 以证明上述(3)的一次性成本。
  • 如需更深入的探讨,请参考此帖子和相关演讲,阐述了将大型数据传递给Pool.map()及其伙伴)会导致问题。

    Raymond Hettinger 在这里也谈到了如何正确使用 Python 的并发。


    2
    请注意,上面的链接引用了我的Python波士顿用户组演讲和博客文章。 - The Aelfinn

    18

    我猜测问题出在进程间通信(IPC)开销上。在单进程实例中,单个进程拥有单词列表。当委派给其他进程时,主进程需要不断地将列表的部分区域传输到其他进程。

    因此,更好的方法可能是启动 n 个进程,每个进程负责加载/生成列表的 1/n 部分,并检查该部分列表中是否包含该单词。

    但我不确定如何使用Python的multiprocessing库实现这一点。


    1
    我同意并怀疑有类似进程启动时间和通信的问题在拖慢我的脚本。最终,我使用了multiprocessing.Process函数,这使我能够手动分割列表并实现10倍的性能提升。请查看我更新的帖子以获取我使用的新代码。 - Karim Bahgat

    5

    我在另一个问题上也遇到了与池类似的情况。目前我不确定实际原因是什么...

    答案由OP Karim Bahgat编辑,这个解决方案对我也起作用。切换到进程和队列系统后,我能够看到与机器核心数相符合的加速效果。

    以下是一个例子。

    def do_something(data):
        return data * 2
    
    def consumer(inQ, outQ):
        while True:
            try:
                # get a new message
                val = inQ.get()
    
                # this is the 'TERM' signal
                if val is None:
                    break;
    
                # unpack the message
                pos = val[0]  # its helpful to pass in/out the pos in the array
                data = val[1]
    
                # process the data
                ret = do_something(data)
    
                # send the response / results
                outQ.put( (pos, ret) )
    
    
            except Exception, e:
                print "error!", e
                break
    
    def process_data(data_list, inQ, outQ):
        # send pos/data to workers
        for i,dat in enumerate(data_list):
            inQ.put( (i,dat) )
    
        # process results
        for i in range(len(data_list)):
            ret = outQ.get()
            pos = ret[0]
            dat = ret[1]
            data_list[pos] = dat
    
    
    def main():
        # initialize things
        n_workers = 4
        inQ = mp.Queue()
        outQ = mp.Queue()
        # instantiate workers
        workers = [mp.Process(target=consumer, args=(inQ,outQ))
                   for i in range(n_workers)]
    
        # start the workers
        for w in workers:
            w.start()
    
        # gather some data
        data_list = [ d for d in range(1000)]
    
        # lets process the data a few times
        for i in range(4):
            process_data(data_list)
    
        # tell all workers, no more data (one msg for each)
        for i in range(n_workers):
            inQ.put(None)
        # join on the workers
        for w in workers:
            w.join()
    
        # print out final results  (i*16)
        for i,dat in enumerate(data_list):
            print i, dat
    

    2

    Pool.map速度较慢,因为需要时间启动进程,并将必要的内存从一个进程传输到所有进程,正如Multimedia Mike所说。我也遇到了类似的问题,我转而使用了multiprocessing.Process

    但是multiprocessing.Process启动进程的时间比Pool.map长。

    解决方案:

    • 提前创建进程并将静态数据放入进程中。
    • 使用队列将数据传递给进程。
    • 还要使用队列从进程接收结果。

    这样,我在配有Windows的Core i5 8265U处理器的笔记本电脑上,在3秒内从100万个人脸特征中搜索到最佳匹配。

    代码 - multiprocess_queue_matcher.py:

    import multiprocessing
    
    from utils import utils
    
    no_of_processes = 0
    input_queues = []
    output_queues = []
    db_embeddings = []
    slices = None
    
    
    def set_data(no_of_processes1, input_queues1, output_queues1, db_embeddings1):
        global no_of_processes
        no_of_processes = no_of_processes1
        global input_queues
        input_queues = input_queues1
        global output_queues
        output_queues = output_queues1
        global db_embeddings
        print("db_embeddings1 size = " + str(len(db_embeddings1)))
        db_embeddings.extend(db_embeddings1)
        global slices
        slices = chunks()
    
    
    def chunks():
        size = len(db_embeddings) // no_of_processes
        return [db_embeddings[i:i + size] for i in range(0, len(db_embeddings), size)]
    
    
    def do_job2(slice, input_queue, output_queue):
        while True:
            emb_to_search = input_queue.get()
            dist1 = 2
            item1 = []
            data_slice = slice
            # emb_to_search = obj[1]
            for item in data_slice:
                emb = item[0]
                dist = utils.calculate_squared_distance(emb_to_search, emb)
                if dist < dist1:
                    dist1 = dist
                    item1 = item
                    item1.append(dist1)
            output_queue.put(item1)
        # if return_value is None:
        #     return item1
        # else:
        #     return_value.set_value(None, item1[1], item1[2], item1[3], item1[4], dist1)
    
    
    def submit_job(emb):
        for i in range(len(slices)):
            input_queues[i].put(emb)
    
    
    def get_output_queues():
        return output_queues
    
    
    def start_processes():
        # slice = self.chunks()
        # ctx = multiprocessing.get_context("spawn")
        # BaseManager.register('FaceData', FaceData)
        # manager = BaseManager()
        # manager.start()
        # return_values = []
        global no_of_processes
        global input_queues
        global output_queues
        processes = []
        pos = 0
        for i in range(no_of_processes):
            p = multiprocessing.Process(target=do_job2, args=(slices[i], input_queues[i], output_queues[i],))
            p.Daemon = True
            processes.append(p)
            pos += 1
            p.start()
    

    然后在需要的地方使用这个模块。

    Flask的高级启动代码:

    mysql = None
    
    db_operator = None
    
    all_db_embeddings = []
    
    input_queues = []
    output_queues = []
    no_of_processes = 4
    
    
    @app.before_first_request
    def initialize():
        global mysql
        global db_operator
        mysql = MySQL(app)
        db_operator = DBOperator(mysql)
        ret, db_embeddings, error_message = db_operator.get_face_data_for_all_face_ids_for_all_users()
        all_db_embeddings.extend(db_embeddings)
        for i in range(no_of_processes):
            in_q = multiprocessing.Queue()
            out_q = multiprocessing.Queue()
            input_queues.append(in_q)
            output_queues.append(out_q)
        multiprocess_queue_matcher.set_data(no_of_processes, input_queues, output_queues, all_db_embeddings)
        multiprocess_queue_matcher.start_processes()
    

    在任何请求端点上按需将作业传递给进程

    emb_to_match = all_db_embeddings[0][0]
        starttime = time.time()
        multiprocess_queue_matcher.submit_job(emb_to_match)
        outputs = []
        for i in range(no_of_processes):
            out_q = output_queues[i]
            outputs.append(out_q.get())
        max = [None, None, None, None, None, 2.0]
        for val in outputs:
            if val[5] < max[5]:
                max = val
        time_elapsed = time.time() - starttime
        return jsonify(
            {"status": "success", "message": "Face search completed", "best_match_faceid": max[1],
             "name": max[2], "distance": max[5], "search_time": time_elapsed})
    

    有关这段代码,您有什么建议和改进的意见吗?

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接