我正在使用Python和Keras (目前使用Theano后端,但我没有更改的问题)。 我有一个神经网络,在并行处理多个信息源。 目前,我一直在单独的进程中运行每个信息源,并从文件中加载其自己的网络副本。这似乎浪费了RAM,因此我认为使用一个单一的多线程进程,其中只有一个网络实例被所有线程使用,将更有效率。但是,我想知道Keras是否与任何后端都是线程安全的。如果我在不同的线程中同时运行.predict(x) 两个不同输入,是否会遇到竞争条件或其他问题?谢谢。
是的,Keras是线程安全的,只要您稍微注意一下。
实际上,在强化学习中有一个算法叫做Asynchronous Advantage Actor Critics (A3C),每个代理都依赖于同一神经网络告诉他们在给定状态下应该做什么。换句话说,每个线程同时调用model.predict
,就像您的问题一样。其中使用Keras的示例实现在这里。
然而,如果您查看了代码,应该特别注意这行代码:
model._make_predict_function() # have to initialize before threading
Keras文档中从未提到过这一点,但在并发环境中必须这样做才能使其正常工作。简而言之,_make_predict_function
是编译predict
函数的函数。在多线程设置中,您必须手动调用此函数以预先编译predict
,否则predict
函数将不会被编译,直到您第一次运行它时,这将在许多线程同时调用时出现问题。您可以在这里看到详细的解释。
到目前为止,我还没有遇到Keras中的其他多线程问题。
引用友好的fcholet的话:
_make_predict_function是一个私有API,我们不应该推荐调用它。
在这里,用户应该先调用predict。
请注意,Keras模型不能保证线程安全。对于CPU推断,请考虑在每个线程中拥有独立的模型副本。
参考:https://dev59.com/NtX8oIgBc1ULPQZFGtOO#75606666
在调用predict()
方法之前复制/克隆模型是可行的。
model = ...
_model = tensorflow.keras.models.clone_model(model)
predictions = _model.predict(...)