TensorFlow将图形保存到文件/从文件加载图形

108

根据我目前所了解的,有几种不同的方法可以将TensorFlow图形转储到文件中,然后加载到另一个程序中,但我还没有找到清晰的例子/信息说明它们的工作方式。我已经知道的是:

  1. 使用tf.train.Saver()将模型的变量保存到检查点文件(.ckpt)中,并在以后恢复它们(source)。
  2. 将模型保存到.pb文件中,然后使用tf.train.write_graph()tf.import_graph_def()将其加载回来(source)。
  3. 从.pb文件中加载模型,重新训练它,并使用Bazel将其转储到新的.pb文件中(source)。
  4. 冻结图形以将图形和权重一起保存(source)。
  5. 使用as_graph_def()保存模型,并将权重/变量映射为常量(source)。

然而,我仍然有几个问题需要澄清:

  1. 关于检查点文件,它们只保存模型的训练权重吗?检查点文件能够被加载到新程序中,并用于运行该模型吗?还是它们仅仅作为在某个时间/阶段保存模型权重的一种方式?
  2. 关于tf.train.write_graph(),权重/变量也会被保存吗?
  • 关于Bazel,它只能保存到/从.pb文件中进行重新训练吗?是否有一个简单的Bazel命令可以将图形转储到一个.pb文件中?
  • 关于冻结,一个冻结的图可以使用tf.import_graph_def()加载吗?
  • TensorFlow的Android演示从一个.pb文件中加载Google的Inception模型。如果我想用自己的.pb文件替换它,该怎么做?我需要改变任何本地代码/方法吗?
  • 总的来说,所有这些方法之间的区别是什么?或者更广泛地说,as_graph_def()/.ckpt/.pb之间的区别是什么?
  • 简而言之,我正在寻找一种将图形(例如各种操作等)及其权重/变量保存到文件中并加载到另一个程序中供使用(不一定是继续/重新训练)的方法。

    关于这个主题的文档并不是很简单明了,因此任何答案/信息都将非常感激。


    2
    最新/最完整的API是元图,它将为您提供一种同时保存以下三个内容的方法--1)图形2)参数值3)集合:https://www.tensorflow.org/versions/r0.10/how_tos/meta_graph/index.html - Yaroslav Bulatov
    2个回答

    87

    在TensorFlow中保存模型的方法有很多种,这可能会让人有些困惑。针对你的每个子问题:

    1. 检查点文件(例如通过调用saver.save()tf.train.Saver对象上)仅包含权重以及在同一程序中定义的任何其他变量。要在另一个程序中使用它们,必须重新创建相关的图结构(例如通过再次运行构建代码或调用tf.import_graph_def()),这告诉TensorFlow如何处理那些权重。请注意,调用saver.save()还会产生一个包含MetaGraphDef的文件,其中包含图形以及如何将来自检查点的权重与该图形关联的详细信息。有关更多详细信息,请参见教程

    2. tf.train.write_graph()仅写入图形结构;不包括权重。

    3. Bazel与读取或写入TensorFlow图形无关。(也许我误解了你的问题:请随时在评论中澄清。)

    4. 可以使用tf.import_graph_def()加载冻结的图形。在这种情况下,权重通常嵌入在图形中,因此您不需要加载单独的检查点。

  • 主要更改将是更新馈入模型的张量名称以及从模型获取的张量名称。在TensorFlow Android演示中,这将对应于传递给TensorFlowClassifier.initializeTensorFlow()inputNameoutputName字符串。

  • GraphDef是程序结构,通常在训练过程中不会改变。检查点是训练过程状态的快照,通常在训练过程的每个步骤中都会发生变化。因此,TensorFlow为这些数据类型使用不同的存储格式,并且低级API提供了不同的保存和加载方法。高级库,例如MetaGraphDef库、Kerasskflow会在这些机制的基础上构建提供更方便的方式来保存和恢复整个模型。


  • 这是否意味着C++ API文档是错误的,当它说你可以加载使用tf.train.write_graph()保存的图并执行它吗? - mnicky
    2
    C++ API文档并没有说谎,但是它缺少一些细节。最重要的细节是,在执行图形时,除了由tf.train.write_graph()保存的GraphDef之外,您还需要记住要提供和获取的张量的名称(上面的第5项)。 - mrry
    @mrry:我尝试使用TensorFlow的DeepDream示例,但似乎需要以pb格式预训练模型!我运行了Cifar10示例,但它只创建检查点!我找不到任何pb文件或其他文件!如何将我的检查点转换为DeepDream示例使用的pb格式? - Hossein
    2
    @Coderx7 我认为你不能将 .ckpt 转换成 .pb,因为 checkpoint 只包含权重和变量,并不知道图的结构。 - David Ortiz
    2
    有没有一个简单的代码可以加载一个.pb文件并运行它? - Kong
    显示剩余3条评论

    1
    你可以尝试下面的代码:

    with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def, name="")
    sess = tf.Session(graph=g_in)
    

    网页内容由stack overflow 提供, 点击上面的
    可以查看英文原文,
    原文链接