如何在TensorFlow中恢复部分图形?

4

我想在tensorflow中仅恢复计算图的一部分。我的架构包含两个网络。第一个网络的输出是第二个网络的输入。第一个网络是预训练的,我想从检查点中恢复它。同时我不想更新第一个网络的参数。是否有示例可以供我参考以实现这一目标?

谢谢


你的检查点是否对两个网络都进行了训练?还是只针对第二个网络进行了训练? - d_void
不,检查点仅包含第一个网络的权重。我已将第一个网络中的变量更改为trainable=False(在训练它们之后)。我想使用第一个网络的输出来训练第二个网络。 - Sentient07
1个回答

5
我没有你任务所需的确切代码,但是这里有一个简短的指南可能会对你有帮助:
首先,你需要将你的网络解析成 `tf.GraphDef` 格式,代码应该像这样:
graph_def = tf.GraphDef()
with tf.gfile.FastGFile("path/to/graphdef") as f:
  s = f.read()
graph_def.ParseFromString(s)

或者从检查点/保存的模型中恢复,然后通过以下方式转换为GraphDef
tf.train.import_meta_graph('checkpoint.meta')
tf.get_default_graph().as_graph_def()

现在您拥有了 graph_def。
其次,使用 tf.graph_util.extract_sub_graph 从 graph_def 中提取子图,您可以指定目标节点作为第二个网络的输入。
最后,使用 tf.import_graph_def 导入第二步中的子图。
另外,由于您不想更新第一个网络的参数,您可以使用 tf.graph_util.convert_variables_to_constants 冻结其参数。

1
@Sentient07 在你的情况下,这是足够的,但是一个具有 trainable=False 的变量只意味着它的值不能通过反向传播进行更改,可以通过 variable.assign(some_new_tensor) 进行更改,但是如果将此变量转换为常量,则无法更改其值。在使用 tf.image_graph_def 导入到Python控制台之前,您可以在任何地方使用 convert_variables_to_constants - Jie.Zhou
你如何保存你的网络,是不是一个 .pb 文件?"path/to/graphdef" 是这个文件的路径。 - Jie.Zhou
我已经使用ckpt创建了一个检查点,同时我也有.meta文件。 - Sentient07
meta_graph_def = tf.MetaGraphDef() with tf.gfile.FastGFile("path/to/graphdef") as f: s = f.read() meta_graph_def.ParseFromString(s) graph_def = meta_graph_def.graph_def - Jie.Zhou
你好,非常感谢你的回答,周杰。它完美地解决了我的问题。现在我想让那个有检查点的图形也可以被训练。我该怎么做呢? - Sentient07
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接