Python multiprocessing: skipping child process segfaults

I am trying to use multiprocessing with a function that can segfault (something I currently have no control over). When a child process hits a segfault, I want only that child task to fail, while all the other child tasks continue and return their results.
I have already switched from `multiprocessing.Pool` to `concurrent.futures.ProcessPoolExecutor` to avoid the issue of child processes hanging forever (or until an arbitrary timeout), as documented in this bug: https://bugs.python.org/issue22393
However, the problem I face now is that when the first child task hits a segfault, all in-flight child processes get marked as broken (`concurrent.futures.process.BrokenProcessPool`).
Is there a way to mark only the child process that actually broke as broken?
The code I am running in `Python 3.7.4`:
import concurrent.futures
import ctypes
from time import sleep


def do_something(x):
    print(f"{x}; in do_something")
    sleep(x*3)
    if x == 2:
        # raise a segmentation fault internally
        return x, ctypes.string_at(0)
    return x, x-1


nums = [1, 2, 3, 1.5]
executor = concurrent.futures.ProcessPoolExecutor()
result_futures = []
for num in nums:
    # Using submit with a list instead of map lets you get past the first exception
    # Example: https://dev59.com/m1UK5IYBdhLWcg3w9jnb#53346191
    future = executor.submit(do_something, num)
    result_futures.append(future)

# Wait for all results
concurrent.futures.wait(result_futures)

# After a segfault is hit for any child process (i.e. is "terminated abruptly"), the process pool becomes unusable
# and all running/pending child processes' results are set to broken
for future in result_futures:
    try:
        print(future.result())
    except concurrent.futures.process.BrokenProcessPool:
        print("broken")

Result:

(1, 0)
broken
broken
(1.5, 0.5)

Desired result:

(1, 0)
broken
(3, 2)
(1.5, 0.5)
2 Answers

Both `multiprocessing.Pool` and `concurrent.futures.ProcessPoolExecutor` make assumptions about how to handle the concurrent interaction between the worker processes and the main process, and those assumptions are violated if any one process is killed or segfaults, so they do the safe thing and mark the whole pool as broken. To get around this, you will need to build your own pool with different assumptions, using `multiprocessing.Process` instances directly.

This might sound intimidating, but a `list` and a `multiprocessing.Manager` will get you most of the way there:

import multiprocessing
import ctypes
import queue
from time import sleep

def do_something(job, result):
    while True:
        x = job.get()
        print(f"{x}; in do_something")
        sleep(x*3)
        if x == 2:
            # raise a segmentation fault internally
            return x, ctypes.string_at(0)
        result.put((x, x-1))

nums = [1, 2, 3, 1.5]

if __name__ == "__main__":
    # you ARE using the spawn context, right?
    ctx = multiprocessing.get_context("spawn")
    manager = ctx.Manager()
    job_queue = manager.Queue(maxsize=-1)
    result_queue = manager.Queue(maxsize=-1)
    pool = [
        ctx.Process(target=do_something, args=(job_queue, result_queue), daemon=True)
        for _ in range(multiprocessing.cpu_count())
    ]
    for proc in pool:
        proc.start()
    for num in nums:
        job_queue.put(num)
    try:
        while True:
            # Timeout is our only signal that no more results coming
            print(result_queue.get(timeout=10))
    except queue.Empty:
        print("Done!")
    print(pool)  # will see one dead Process 
    for proc in pool:
        proc.kill()  # avoid stderr spam

这个"池子"(Pool)有点不够灵活,你可能需要根据应用程序的特定需求进行自定义。但你肯定可以避免工作进程崩溃。
When I went down the rabbit hole of canceling specific submissions in a worker pool, I ended up writing a whole library for use with Trio async applications: trio-parallel. Hopefully you won't need to go that far!
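For reference, a minimal trio-parallel call looks roughly like the sketch below. It is based on the library's documented `run_sync` entry point; treat the exact signature and crash-handling behavior as an assumption rather than a verified snippet.

import trio
import trio_parallel


def hard_work(x):
    return x * 2


async def main():
    # Each call is dispatched to a worker process; if that worker crashes,
    # the failure is meant to surface for this call only rather than poisoning
    # every other in-flight submission (assumption based on the library's docs).
    print(await trio_parallel.run_sync(hard_work, 21))


trio.run(main)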


Building on @Richard Sheridan's answer, I ended up using the code below. This version doesn't require setting a timeout, which wasn't an option for my use case.

import ctypes
import multiprocessing
from typing import Dict
from time import sleep


def do_something(x, result):
    print(f"{x} starting")
    sleep(x * 3)
    if x == 2:
        # raise a segmentation fault internally
        y = ctypes.string_at(0)
    y = x
    print(f"{x} done")
    result.put(y)  # use the queue passed in as an argument

def wait_for_process_slot(
    processes: Dict,
    concurrency: int = multiprocessing.cpu_count() - 1,
    wait_sec: int = 1,
) -> int:
    """Blocks main process if `concurrency` processes are already running.

    Alternative to `multiprocessing.Semaphore.acquire`
    useful for when child processes might fail and not be able to signal.
    Relies instead on the main's (parent's) tracking of `multiprocessing.Process`es.

    """
    counter = 0
    while True:
        counter = sum(1 for p in processes.values() if p.is_alive())
        if counter < concurrency:
            return counter
        sleep(wait_sec)


if __name__ == "__main__":
    # "spawn" results in an OSError b/c pickling a segfault fails?
    ctx = multiprocessing.get_context()
    manager = ctx.Manager()
    results_queue = manager.Queue(maxsize=-1)

    concurrency = multiprocessing.cpu_count() - 1  # reserve 1 CPU for waiting
    nums = [3, 1, 2, 1.5]
    all_processes = {}
    for idx, num in enumerate(nums):
        num_running_processes = wait_for_process_slot(all_processes, concurrency)

        p = ctx.Process(target=do_something, args=(num, results_queue), daemon=True)
        all_processes.update({idx: p})
        p.start()

    # Wait for the last batch of processes not blocked by wait_for_process_slot to finish
    for p in all_processes.values():
        p.join()

    # Check last batch of processes for bad processes
    # Relies on all processes having finished (the p.joins above)
    bad_nums = [idx for idx, p in all_processes.items() if p.exitcode != 0]
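If you also want to tell segfaults apart from other failures: a child process killed by a signal reports a negative `exitcode` equal to minus the signal number, so a segfaulted child usually shows `-signal.SIGSEGV` (-11 on Linux). This detail is not part of the original answer; the lines below are a small follow-up sketch that continues inside the `if __name__ == "__main__":` block and uses the `all_processes` dict built above.

    # (add `import signal` with the other imports at the top of the file)
    # Processes terminated by a signal have exitcode == -<signal number>,
    # so a segfaulted child typically reports -signal.SIGSEGV.
    segfaulted_nums = [
        idx for idx, p in all_processes.items()
        if p.exitcode == -signal.SIGSEGV
    ]
    print(f"Jobs that segfaulted: {segfaulted_nums}")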
