如何在Tensorflow API中运行多个图表?

17

Tensorflow API提供了一些预训练模型,并允许我们使用任何数据集进行训练。

我想知道如何在一个tensorflow会话中初始化和使用多个图。我想要导入两个已训练好的模型到两个不同的图中,并将它们用于物体检测,但我迷失在尝试在一个会话中运行多个图的过程中。

有没有特别的方法来在一个会话中处理多个图?

另一个问题是,即使我为两个不同的图创建了两个不同的会话并尝试与它们一起工作,我最终得到的第二个会话的结果与第一个实例化的会话相似。

3个回答

13

每个 Session 只能拥有一个 Graph。不过,根据你特定的需求,你有几种选择。

第一种选择是创建两个分别加载一个图的独立会话。你提到使用这种方法,每个会话得到的结果都出现了意外的相似性,但是没有更多细节信息很难判断具体问题出在哪里。我猜测可能是同样的图被加载到了每个会话中,或者你试图单独运行每个会话时运行了同一个会话,但是没有更多细节信息很难确定。

第二种选择是将两个图作为主会话图的子图加载。你可以在图中创建两个作用域,并在每个要加载的图中构建该图的结构。然后你可以像处理独立图一样处理它们,因为它们之间没有连接。当运行正常的图全局函数时,需要指定这些函数应用于哪个作用域。例如,在使用一个子图的优化器更新该子图时,你需要使用类似这个答案中所示的方式只获取该子图作用域下的可训练变量。

除非你明确需要这两个图在 TensorFlow 图中相互交互,否则我建议选择第一种方法,这样你就不需要额外处理子图所需的操作(如需要在任何给定时刻过滤哪个作用域正在工作,以及可能共享于两个图之间的全局事务)。


谢谢回复。我不了解第二个选项的性能,但创建两个会话可能会对CPU / GPU造成很大的负载,然后我们可能无法实时使用它们。您认为采用第二个选项对CPU的影响是否相似或更小? 我将尽快尝试为您提供有关创建不同会话问题的更多详细信息。 - saikishor
除非两个加载的图共享变量,否则我认为拥有两个会话不会导致比将它们都加载到单个会话中更大的CPU/GPU使用量。在大多数情况下,会话的开销应该与图形元素本身的内存使用相比非常小。 - golmschenk
请在此处找到代码:https://pastebin.com/VnN8f8FC。如果可以,请尝试对此进行评论。谢谢。 - saikishor
@SaiKishorKothakota:我不确定您在代码的其他部分可能还要使用这些图形做什么,但看起来您只加载了“GraphDef”。它通常只包含图形结构,而不包含训练好的权重。如果您正在使用预训练模型,则还需要从检查点文件中加载权重。如果您正在训练模型,请确保您不会在每个训练步骤中使用这些函数之一重新创建图形。 - golmschenk
我正在加载图形权重,这就是它在这里的实现方式:https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb。类似的,我只是使用函数,我发现我可以修改函数之间的函数调用,第二个被调用的函数的输出与第一个函数相似,无论我在调用中提供的会话信息如何。如果权重没有加载,显然我将无法获得结果。 - saikishor

5
我曾面临相同的挑战,在数月的研究后终于解决了这个问题。我使用了 tf.graph_util.import_graph_def。根据文档

name:(可选)将添加到图定义中名称前面的前缀。请注意,这不适用于导入的函数名称。默认为“import”。

因此,通过添加此前缀,可以区分不同的会话。
例如:
first_graph_def = tf.compat.v1.GraphDef()
second_graph_def = tf.compat.v1.GraphDef()

# Import the TF graph : first
first_file = tf.io.gfile.GFile(first_MODEL_FILENAME, 'rb')
first_graph_def.ParseFromString(first_file.read())
first_graph = tf.import_graph_def(first_graph_def, name='first')

# Import the TF graph : second
second_file = tf.io.gfile.GFile(second_MODEL_FILENAME, 'rb')
second_graph_def.ParseFromString(second_file.read())
second_graph = tf.import_graph_def(second_graph_def, name='second')

# These names are part of the model and cannot be changed.
first_output_layer = 'first/loss:0'
first_input_node = 'first/Placeholder:0'

second_output_layer = 'second/loss:0'
second_input_node = 'second/Placeholder:0'

# initialize probability tensor
first_sess = tf.compat.v1.Session(graph=first_graph)
first_prob_tensor = first_sess.graph.get_tensor_by_name(first_output_layer)

second_sess = tf.compat.v1.Session(graph=second_graph)
second_prob_tensor = second_sess.graph.get_tensor_by_name(second_output_layer)

first_predictions, = first_sess.run(
        first_prob_tensor, {first_input_node: [adapted_image]})
    first_highest_probability_index = np.argmax(first_predictions)

second_predictions, = second_sess.run(
        second_prob_tensor, {second_input_node: [adapted_image]})
    second_highest_probability_index = np.argmax(second_predictions)

如您所见,现在您可以在一个tensorflow会话中初始化并使用多个图表。

希望这对您有帮助。


1

在一个会话中,参数“arg”应该是None或者是图形的实例。

这里是源代码

class BaseSession(SessionInterface):
  """A class for interacting with a TensorFlow computation.
  The BaseSession enables incremental graph building with inline
  execution of Operations and evaluation of Tensors.
  """

  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.
    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.
    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))

从下面的代码片段中可以看出,它不可能是一个列表。

if graph is None:
  self._graph = ops.get_default_graph()
else:
  if not isinstance(graph, ops.Graph):
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))

ops.Graph(通过help(ops.Graph)查找)对象中可以看出,它不能是多个图形。

有关会话和图形的更多信息

If no `graph` argument is specified when constructing the session,
the default graph will be launched in the session. If you are
using more than one graph (created with `tf.Graph()` in the same
process, you will have to use different sessions for each graph,
but each graph can be used in multiple sessions. In this case, it
is often clearer to pass the graph to be launched explicitly to
the session constructor.

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接