在多线程中重复使用TensorFlow会话会导致崩溃

7

背景:

我有一些复杂的强化学习算法,我想在多个线程中运行。

问题

尝试在线程中调用sess.run时,会出现以下错误消息:

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

代码重现错误:

import tensorflow as tf

import threading

def thread_function(sess, i):
    inn = [1.3, 4.5]
    A = tf.placeholder(dtype=float, shape=(None), name="input")
    P = tf.Print(A, [A])
    Q = tf.add(A, P)
    sess.run(Q, feed_dict={A: inn})

def main(sess):

    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(sess, i))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()

if __name__ == '__main__':

    sess = tf.Session()
    main(sess)

如果我在线程外运行相同的代码,它可以正常工作。

有人可以提供一些关于如何在Python线程中正确使用Tensorflow会话的见解吗?

2个回答

8

会话(Session)不仅可以作为当前线程的默认会话,还可以作为图形(Graph)的默认会话。 当您传入会话并在其上调用run时,默认图形将是另一个图形。

您可以像这样修改thread_function以使其正常工作:

def thread_function(sess, i):
    with sess.graph.as_default():
        inn = [1.3, 4.5]
        A = tf.placeholder(dtype=float, shape=(None), name="input")
        P = tf.Print(A, [A])
        Q = tf.add(A, P)
        sess.run(Q, feed_dict={A: inn})

然而,我不希望期望有任何显著的加速。在Python中,线程并不像其他一些语言中那样,只有像io这样的某些操作会并行运行。对于 CPU 密集型操作它并不是非常有用。多进程可以真正并行地运行代码,但你将无法共享同一个会话。


太好了,谢谢!我也找到了这个答案,它有效。我不是在寻求加速,而是想在单个环境中为不同的代理使用不同的模型。 - Andreas Pasternak

4

在de1的回答上,我提供另一个在github上的资源: tensorflow/tensorflow#28287 (comment)

以下内容对我解决了tf的多线程兼容性问题:

# on thread 1
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
    k.backend.set_session(session)
    model = k.models.load_model(filepath)

# on thread 2
with session.graph.as_default():
    k.backend.set_session(session)
    model.predict(x)

这将为其他线程保留SessionGraph。模型加载到它们的“上下文”中(而不是默认值),并且保留供其他线程使用。
(默认情况下,模型会加载到默认的Session和默认的Graph中)
另一个好处是它们保持在同一对象中 - 更易于处理。

这也可以不使用Keras API实现吗? - Varlor
@Varlor 请尝试将 k.backend.set_session(session) 更改为 with session.as_default(): - EliadL
有没有一种方法可以在TF2中模拟相同的过程? - SimoX

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接