Tensorflow:如何将.meta、.data和.index模型文件转换为一个graph.pb文件

38
在TensorFlow中,从头开始训练会生成以下6个文件:
  1. events.out.tfevents.1503494436.06L7-BRM738
  2. model.ckpt-22480.meta
  3. checkpoint
  4. model.ckpt-22480.data-00000-of-00001
  5. model.ckpt-22480.index
  6. graph.pbtxt
我想将它们(或所需的文件)转换为一个文件graph.pb,以便能够将其传输到我的Android应用程序。我尝试了脚本freeze_graph.py,但它需要input.pb文件作为输入,而我没有这个文件。(我只有前面提到的这6个文件)。如何才能得到这个freezed_graph.pb文件?我看到过几个线程,但都对我没用。

1
请参考此处:https://stackoverflow.com/questions/45433231/freezing-a-cnn-tensorflow-model-into-a-pb-file/45437684#45437684 - Dat Tran
你是如何获得 graph.pbtxt 的?如果这是你模型的图形,则可以使用 freeze.py 将其冻结为 .pbtxt - velikodniy
在完成训练后,我在训练日志中找到了graph.pbtxt文件。然而,在训练完成之前就已经保存了它。请在图形的先前保存状态中检查它。对于从头开始的训练,我使用了train_image_classifier.py脚本。在训练过程中,我使用了自己的图片(.jpg),但在使用build_image_data.py脚本之前,我需要将其转换为.tfrecord文件。 - Rafal
4个回答

46

您可以使用这个简单的脚本来实现。但是您必须指定输出节点的名称。

import tensorflow as tf

meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('path/of/your/.meta/file'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

如果您不知道输出节点的名称,有两种方法:

  1. 您可以使用 Netron 或控制台summarize_graph 工具浏览图形并找到名称。

  2. 您可以将所有节点用作以下示例中的输出节点。

output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]

(请注意,在调用convert_variables_to_constants之前必须放置此行。)

但我认为这是不寻常的情况,因为如果您不知道输出节点,则实际上无法使用图形。


8
有没有简单的方法获取输出节点的名称? - Michael Ramos
2
我遇到了这个错误,可能是因为我不确定我的output_node_names是否正确。File "/path/to/saver.py", line 1796, in restore raise ValueError("Can't load save_path when it is None.") - craq
1
@craq 看起来在当前目录(路径名为“.”)中找不到任何检查点。尝试显式设置检查点的路径:saver.restore(sess, 'path/to/model.ckpt') - velikodniy
3
如果有人遇到和我一样的问题,尝试冻结图表时出现“Attempting to use uninitialized value”的错误,请在加载权重后添加以下代码: init=tf.global_variables_initializer() sess.run(init) - Eek
1
通常不是所有可训练的变量都是您所需的输出节点。此外,输出节点可能根本不是一个变量。例如,在表达式 a * x + 1 的图形中,输出节点是 add - velikodniy
显示剩余10条评论

7

作为对其他人有帮助的提示,我在github上回答后也在这里回答;-)。我认为你可以尝试像这样做(使用tensorflow/python/tools中的freeze_graph脚本):

python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "

这里的重要标志是--input_binary=false,因为graph.pbtxt文件是文本格式。我认为它对应于所需的二进制格式的等效文件graph.pb。

关于output_node_names,对我来说真的很困惑,因为我在这个部分仍然有一些问题,但是您可以在tensorflow中使用summarize_graph脚本,它可以将pb或pbtxt作为输入。

祝好,

Steph


在 ssd_mobilnet_v1_coco 中,我应该使用什么来替换 -out_node_name? - abhimanyuaryan
@PratikKhadloya,你能回答我上面的评论吗? - abhimanyuaryan
示例用法:python freeze_graph.py --input_graph=some_graph_def.pb --input_checkpoint=model.ckpt-8361242 --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax - Pratik Khadloya

3

我试过freezed_graph.py脚本,但是output_node_name参数让我完全搞不清楚。任务失败了。

于是我试了另外一个:export_inference_graph.py。结果符合预期!

python -u /tfPath/models/object_detection/export_inference_graph.py \
  --input_type=image_tensor \
  --pipeline_config_path=/your/config/path/ssd_mobilenet_v1_pets.config \
  --trained_checkpoint_prefix=/your/checkpoint/path/model.ckpt-50000 \
  --output_directory=/output/path

我使用的TensorFlow安装包来自于这里:https://github.com/tensorflow/models

嗨@kennynut,--pipeline_config_path是什么? 在这种文件中写了什么内容,你能给我一个例子吗? 我已经使用Tensorflow有一段时间了,但从未需要使用这样的配置文件。 - Scott Yang
pipeline_config_path 提供了冻结图正常运行所需的基本配置。通常情况下,它带有默认名称 pipeline.config,并位于来自 Google Git Hub 存储库中模型动物园压缩包的默认根路径下。 - Neveroldmilk

1

首先,使用以下代码生成graph.pb文件。 with tf.Session() as sess:

    # Restore the graph
    _ = tf.train.import_meta_graph(args.input)

    # save graph file
    g = sess.graph
    gdef = g.as_graph_def()
    tf.train.write_graph(gdef, ".", args.output, True)

然后,使用summarize graph获取输出节点名称。最后,使用。
python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "

生成冻结图。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接