追踪joblib.Parallel执行的进度

62

有没有一种简单的方法来跟踪joblib.Parallel执行的总进度?

我有一个长时间运行的执行组成了数千个任务,我想要跟踪和记录在数据库中。然而,为了做到这一点,每当Parallel完成一个任务时,我需要它执行回调,报告还剩下多少个任务。

我之前用Python的stdlib multiprocessing.Pool完成了类似的任务,通过启动一个线程来记录Pool作业列表中待处理作业的数量。

看着代码,Parallel继承了Pool,所以我认为我可以使用相同的技巧,但它似乎没有使用这些列表,并且我还没有能够弄清楚如何以其他方式"读取"它的内部状态。

11个回答

86

从dano和Connor的答案中再往前迈出一步的方法是将整个内容包装为上下文管理器:

import contextlib
import joblib
from tqdm import tqdm

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()

然后你可以像这样使用它,完成后不要留下猴子补丁代码:

from math import sqrt
from joblib import Parallel, delayed

with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
    Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

我认为这很棒,它看起来类似于tqdm pandas集成。


1
优秀的解决方案。已经测试过,可以与joblib 0.14.1和tqdm 4.41.0兼容 - 运行良好。这将是tqdm的一个很好的补充! - dennisobrien
5
我无法进行编辑,但是解决方案中有一个小错别字,joblib.parallel.BatchCompletionCallback 实际上应该是 BatchCompletionCallBack(注意 CallBack 是驼峰式写法)。 - Andrew
3
我刚刚将这段代码发布到PyPI:https://github.com/louisabraham/tqdm_joblib现在您只需运行 pip install tqdm_joblibfrom tqdm_joblib import tqdm_joblib 即可。 - Labo
2
我认为这个不再起作用了。 - Ansh David

28

为什么你不能简单地使用tqdm?下面这个方法对我有效:

from joblib import Parallel, delayed
from datetime import datetime
from tqdm import tqdm

def myfun(x):
    return x**2

results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000))
100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]

59
我认为这实际上并没有监控正在运行的作业的完成情况,只是在排队的作业。如果你在 myfun 开始处插入 time.sleep(1),你会发现 tqdm 进度条几乎瞬间就完成了,但是 results 需要多几秒钟才能填充。 - Noah
6
是的,那部分是正确的。它正在跟踪工作开始与完成的情况,但另一个问题是,在所有工作完成后,还会有一些开销导致延迟。一旦所有任务完成,需要收集结果,这可能需要相当长的时间。 - Jon
3
我认为这个回答并没有真正回答问题。正如提到的那样,这种方法将跟踪“排队”,而不是执行本身。下面展示的使用回调函数的方法似乎更准确地回答了这个问题。 - devforfu
12
这个回答是不正确的,因为它没有回答问题。这个回答应该被取消接受。 - Henry Henrinson
2
虽然这个答案在技术上确实是错误的,正如一些评论所指出的那样,但它仍然有用:它是最简单的解决方案,比其他答案要容易得多,并且当我使用它处理大量的短任务时,完成时间不会太长,所以在某些情况下它可能已经足够好了。 - someone
显示剩余5条评论

22

您提供的文档指出Parallel有一个可选的进度条。它是通过使用multiprocessing.Pool.apply_async提供的callback关键字参数来实现的:

# This is inside a dispatch function
self._lock.acquire()
job = self._pool.apply_async(SafeFunction(func), args,
            kwargs, callback=CallBack(self.n_dispatched, self))
self._jobs.append(job)
self.n_dispatched += 1

...

class CallBack(object):
    """ Callback used by parallel: it is used for progress reporting, and
        to add data to be processed
    """
    def __init__(self, index, parallel):
        self.parallel = parallel
        self.index = index

    def __call__(self, out):
        self.parallel.print_progress(self.index)
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

这里是print_progress函数:

def print_progress(self, index):
    elapsed_time = time.time() - self._start_time

    # This is heuristic code to print only 'verbose' times a messages
    # The challenge is that we may not know the queue length
    if self._original_iterable:
        if _verbosity_filter(index, self.verbose):
            return
        self._print('Done %3i jobs       | elapsed: %s',
                    (index + 1,
                     short_format_time(elapsed_time),
                    ))
    else:
        # We are finished dispatching
        queue_length = self.n_dispatched
        # We always display the first loop
        if not index == 0:
            # Display depending on the number of remaining items
            # A message as soon as we finish dispatching, cursor is 0
            cursor = (queue_length - index + 1
                      - self._pre_dispatch_amount)
            frequency = (queue_length // self.verbose) + 1
            is_last_item = (index + 1 == queue_length)
            if (is_last_item or cursor % frequency):
                return
        remaining_time = (elapsed_time / (index + 1) *
                    (self.n_dispatched - index - 1.))
        self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
                    (index + 1,
                     queue_length,
                     short_format_time(elapsed_time),
                     short_format_time(remaining_time),
                    ))

他们实现这个的方式有点奇怪,说实话——它似乎假定任务总是按照开始的顺序完成。传递给 print_progress 函数的 index 变量只是在实际启动作业时的 self.n_dispatched 变量。因此,即使第三个作业先完成,第一个启动的作业也始终会以 0 的 index 结束。这也意味着他们实际上没有跟踪已完成的工作数量。因此,你无法监视任何实例变量。

我认为你最好自己制作一个回调类,并对 Parallel 进行 monkey patch。

from math import sqrt
from collections import defaultdict
from joblib import Parallel, delayed

class CallBack(object):
    completed = defaultdict(int)

    def __init__(self, index, parallel):
        self.index = index
        self.parallel = parallel

    def __call__(self, index):
        CallBack.completed[self.parallel] += 1
        print("done with {}".format(CallBack.completed[self.parallel]))
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

import joblib.parallel
joblib.parallel.CallBack = CallBack

if __name__ == "__main__":
    print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))

输出:

done with 1
done with 2
done with 3
done with 4
done with 5
done with 6
done with 7
done with 8
done with 9
done with 10
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

这样,你的回调函数将在作业完成时被调用,而不是使用默认值。


1
非常好的研究,谢谢。我没有注意到回调属性。 - Cerin
我发现joblib的文档非常有限。我不得不深入源代码中的CallBack类。我的问题是:当调用__call__时,我能否自定义参数?(子类化整个Parallel类可能是一种方法,但对我来说太重了)。 - Ziyuan

11

针对最新版本的joblib库,进一步解释了dano的答案。内部实现进行了几处更改。

from joblib import Parallel, delayed
from collections import defaultdict

# patch joblib progress callback
class BatchCompletionCallBack(object):
  completed = defaultdict(int)

  def __init__(self, time, index, parallel):
    self.index = index
    self.parallel = parallel

  def __call__(self, index):
    BatchCompletionCallBack.completed[self.parallel] += 1
    print("done with {}".format(BatchCompletionCallBack.completed[self.parallel]))
    if self.parallel._original_iterator is not None:
      self.parallel.dispatch_next()

import joblib.parallel
joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack

9

简短解决方案:

适用于python 3.5版本,使用joblib 0.14.0和tqdm 4.46.0。感谢frenzykryger提供的contextlib建议,以及dano和Connor提出的monkey patching想法。

import contextlib
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument"""

    def tqdm_print_progress(self):
        if self.n_completed_tasks > tqdm_object.n:
            n_completed = self.n_completed_tasks - tqdm_object.n
            tqdm_object.update(n=n_completed)

    original_print_progress = joblib.parallel.Parallel.print_progress
    joblib.parallel.Parallel.print_progress = tqdm_print_progress

    try:
        yield tqdm_object
    finally:
        joblib.parallel.Parallel.print_progress = original_print_progress
        tqdm_object.close()

您可以按照 frenzykryger 描述的方式使用它。
import time
def some_method(wait_time):
    time.sleep(wait_time)

with tqdm_joblib(tqdm(desc="My method", total=10)) as progress_bar:
    Parallel(n_jobs=2)(delayed(some_method)(0.2) for i in range(10))
更长的解释: Jon提供的解决方案实现简单,但它只测量已分派的任务。如果任务执行时间很长,进度条会在等待最后一个分派的任务完成执行时停留在100%。
frenzykryger改进了dano和Connor的上下文管理器方法,这是更好的方法,但BatchCompletionCallBack也可能在任务完成之前使用ImmediateResult调用(请参见Intermediate results from joblib)。这将导致我们得到一个超过100%的计数。
我们可以不使用猴子补丁来修补BatchCompletionCallBack,而只需修补Parallel中的print_progress函数。因为BatchCompletionCallBack已经调用了此print_progress。如果设置了verbose(即Parallel(n_jobs=2,verbose=100)),则print_progress将打印已完成的任务,但不如tqdm那样美观。查看代码,print_progress是一个类方法,因此它已经有self.n_completed_tasks记录我们想要的数值。我们只需将其与joblib进度的当前状态进行比较,并仅在存在差异时更新即可。
这在使用python 3.5的joblib 0.14.0和tqdm 4.46.0中进行了测试。

4

文本进度条

对于那些想要不使用额外模块(如tqdm)的文本进度条的人来说,这是另一种选择。适用于joblib=0.11、python 3.5.2和linux系统(截至2018年4月16日),并在子任务完成时显示进度。

重新定义原生类:

class BatchCompletionCallBack(object):
    # Added code - start
    global total_n_jobs
    # Added code - end
    def __init__(self, dispatch_timestamp, batch_size, parallel):
        self.dispatch_timestamp = dispatch_timestamp
        self.batch_size = batch_size
        self.parallel = parallel

    def __call__(self, out):
        self.parallel.n_completed_tasks += self.batch_size
        this_batch_duration = time.time() - self.dispatch_timestamp

        self.parallel._backend.batch_completed(self.batch_size,
                                           this_batch_duration)
        self.parallel.print_progress()
        # Added code - start
        progress = self.parallel.n_completed_tasks / total_n_jobs
        print(
            "\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(progress * 50), progress*100)
            , end="", flush=True)
        if self.parallel.n_completed_tasks == total_n_jobs:
            print('\n')
        # Added code - end
        if self.parallel._original_iterator is not None:
            self.parallel.dispatch_next()

import joblib.parallel
import time
joblib.parallel.BatchCompletionCallBack = BatchCompletionCallBack

在使用之前定义全局常量,包含所有工作的总数:

total_n_jobs = 10

这将导致类似以下的结果:
Progress: [########################################          ] 80.0%

1
非常好。如果您还想打印时间估计,可以使用以下方式调整__call__time_remaining = (this_batch_duration / self.batch_size) * (total_n_jobs - self.parallel.n_completed_tasks) print( "\r进度:[{0:50s}] {1:.1f}% 预计剩余{2:1f}分钟".format('#' * int(progress * 50), progress*100, time_remaining/60) , end="", flush=True) - lawrencegripper

4

截至于2023年6月发布的joblib v1.3.0版本,有一种更简单的方法可以使用tqdm进度条包装joblib.Parallel(受this comment启发)。

这个进度条将跟踪作业完成情况,而不是作业加入队列。以前,这需要一个特殊的上下文管理器。以下是一个示例:

from joblib import Parallel, delayed
from tqdm import tqdm

import time
import random

# Our example worker will sleep for a certain number of seconds.

inputs = list(range(10))
random.shuffle(inputs)

def worker(n_seconds):
    time.sleep(n_seconds)
    return n_seconds

# Run the worker jobs in parallel, with a tqdm progress bar.
# We configure Parallel to return a generator.
# Then we wrap the generator in tqdm.
# Finally, we execute everything by converting the tqdm generator to a list.

outputs = list(
    tqdm(
        # Note the new return_as argument here, which requires joblib >= 1.3:
        Parallel(return_as="generator", n_jobs=3)(
            delayed(worker)(n_seconds) for n_seconds in inputs
        ),
        total=len(inputs),
    )
)
print(outputs)

1
太棒了!我认为现在joblib v1.3.0已经发布了,这可能应该成为被接受的答案。它的效果很好,比其他解决方案简单得多。 - undefined

1

这是另一个回答你问题的语法:

aprun = ParallelExecutor(n_jobs=5)

a1 = aprun(total=25)(delayed(func)(i ** 2 + j) for i in range(5) for j in range(5))
a2 = aprun(total=16)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
a2 = aprun(bar='txt')(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
a2 = aprun(bar=None)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))

https://dev59.com/V1oU5IYBdhLWcg3wHEVp#40415477


0
import joblib
class ProgressParallel(joblib.Parallel):
    def __init__(self, n_total_tasks=None, **kwargs):
        super().__init__(**kwargs)
        self.n_total_tasks = n_total_tasks

    def __call__(self, *args, **kwargs):
        with tqdm() as self._pbar:
            return joblib.Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        if self.n_total_tasks:
            self._pbar.total = self.n_total_tasks
        else:
            self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()

你介意在你的代码中加入一些解释吗? - Lover of Structure

0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接