Keras是线程安全的吗?

25
我正在使用Python和Keras (目前使用Theano后端,但我没有更改的问题)。 我有一个神经网络,在并行处理多个信息源。 目前,我一直在单独的进程中运行每个信息源,并从文件中加载其自己的网络副本。这似乎浪费了RAM,因此我认为使用一个单一的多线程进程,其中只有一个网络实例被所有线程使用,将更有效率。但是,我想知道Keras是否与任何后端都是线程安全的。如果我在不同的线程中同时运行.predict(x) 两个不同输入,是否会遇到竞争条件或其他问题?谢谢。
3个回答

27

是的,Keras是线程安全的,只要您稍微注意一下。

实际上,在强化学习中有一个算法叫做Asynchronous Advantage Actor Critics (A3C),每个代理都依赖于同一神经网络告诉他们在给定状态下应该做什么。换句话说,每个线程同时调用model.predict,就像您的问题一样。其中使用Keras的示例实现在这里

然而,如果您查看了代码,应该特别注意这行代码: model._make_predict_function() # have to initialize before threading

Keras文档中从未提到过这一点,但在并发环境中必须这样做才能使其正常工作。简而言之,_make_predict_function 是编译predict函数的函数。在多线程设置中,您必须手动调用此函数以预先编译predict,否则predict函数将不会被编译,直到您第一次运行它时,这将在许多线程同时调用时出现问题。您可以在这里看到详细的解释。

到目前为止,我还没有遇到Keras中的其他多线程问题。


2
我认为Daniel的回答更准确和更新。 - asaf92

8

引用友好的fcholet的话:

_make_predict_function是一个私有API,我们不应该推荐调用它。

在这里,用户应该先调用predict。

请注意,Keras模型不能保证线程安全。对于CPU推断,请考虑在每个线程中拥有独立的模型副本。


0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接