Keras + Tensorflow和Python中的多进程技术

45
我正在使用Tensorflow作为后端,结合Keras进行开发。
我想在主进程中保存模型,然后在另一个进程中加载/运行模型(即调用`model.predict`)。
目前,我正在尝试文档中的简单方法来保存/加载模型:https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model。基本上:
  1. 在主进程中使用`model.save()`
  2. 在子进程中使用`model = load_model()`
  3. 在子进程中使用`model.predict()`
但是,在`load_model`调用处它会一直卡住。
我查找了一些资料,发现了这个可能相关的回答,该回答表明Keras只能在一个进程中使用:using multiprocessing with theano。但我不确定这是否正确(貌似没有太多关于此的信息)。
有什么方法可以实现我的目标吗?非常感谢提供高层描述或简短的示例。
注意:我已经尝试过将图形传递给进程的方法,但失败了,因为似乎TensorFlow图形无法pickle化(相关的SO帖子在这里:Tensorflow: Passing a session to a python multiprocess)。如果确实有一种方法可以将tensorflow图形/模型传递给子进程,那我也很愿意尝试。
谢谢!
3个回答

62

根据我的经验,问题出在一个进程中加载了Keras,然后在已经将Keras加载到主环境中的情况下派生一个新进程。但对于某些应用程序(例如训练一组Keras模型),将所有这些内容放在一个进程中更好。因此,我建议采用以下方法(有点繁琐,但适合我):

  1. 不要将Keras加载到主环境中。如果要加载Keras/Theano/TensorFlow,请仅在函数环境中执行。例如,不要执行以下操作:

    import keras
    
    def training_function(...):
        ...
    

    但请执行以下操作:

    def training_function(...):
        import keras
        ...
    
  2. 将与每个模型相关的工作运行在单独的进程中:通常我会创建一些工作进程,用于执行任务(例如训练、调整和评分),并将它们运行在单独的进程中。这样做的好处是,当进程完成时,该进程使用的所有内存都会被完全释放。这有助于解决许多内存问题,通常在使用多进程或在一个进程中运行多个模型时会遇到。因此,这看起来像这样:

  3. def _training_worker(train_params):
        import keras
        model = obtain_model(train_params)
        model.fit(train_params)
        send_message_to_main_process(...)
    
    def train_new_model(train_params):
        training_process = multiprocessing.Process(target=_training_worker, args = train_params)
        training_process.start()
        get_message_from_training_process(...)
        training_process.join()
    

采用不同的方法是为不同模型动作准备不同的脚本。但这可能会导致内存错误,特别是当您的模型需要大量内存时。请注意,由于此原因,最好使您的执行严格顺序。


Marcin,非常感谢您的回答。附带问题:上面的SO问题源于我想要在多个GPU上并行预测一个模型。我之前问过一个SO问题,但没有得到解决:https://dev59.com/j1gQ5IYBdhLWcg3wqFsF。您能否试着解决一下?提前感谢您。如果需要更多细节,请告诉我。 - John Cast
我会看一下。这也是我感兴趣的东西,但我还没有实现这样的解决方案。我实现了一个类似守护进程的过程,通过keras深度CNN的组合接收要处理的图像,这就是我的例子来源。 - Marcin Możejko
创建子进程并没有帮助我解决这个问题。当我中断keras模型拟合时,会出现很多挂起的进程。在调查过程中,我发现多进程池没有被try-except和池终止包装。https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L548 我认为这是keras的问题,但我不确定。 - luminousmen
在函数内部使用importlib.reload(keras),但它没有起作用。 - Kermit

8

1
我创建了一个修复我的代码的装饰器。
from multiprocessing import Pipe, Process

def child_process(func):
    """Makes the function run as a separate process."""
    def wrapper(*args, **kwargs):
        def worker(conn, func, args, kwargs):
            conn.send(func(*args, **kwargs))
            conn.close()
        parent_conn, child_conn = Pipe()
        p = Process(target=worker, args=(child_conn, func, args, kwargs))
        p.start()
        ret = parent_conn.recv()
        p.join()
        return ret
return wrapper

@child_process
def keras_stuff():
    """ Keras stuff here"""

这里有很多内容。首先是在Python中创建一个装饰器,这使我能够创建具有相似行为的新函数。对于此装饰器,其行为是作为子进程运行。如果您以前没有创建过装饰器,请在此处阅读更多信息:https://realpython.com/primer-on-python-decorators/ 除此之外,工作函数只需使用Python多进程执行修饰函数“func”。此代码在继续之前等待子进程完成。请在此处了解有关“multiprocessing”的更多信息:https://docs.python.org/3/library/multiprocessing.html - Mark

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接