如何在Tensorflow r12中通过文件名恢复模型?

13

我已经运行了分布式mnist示例:https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/tools/dist_test/python/mnist_replica.py

虽然我已经设置了

saver = tf.train.Saver(max_to_keep=0)

在之前的版本中,例如r11,我能够运行每个检查点模型并评估模型的精度。这给了我一个关于精度随全局步数(或迭代次数)变化的进展图。

在r12之前,tensorflow检查点模型保存在两个文件中,model.ckpt-1234model-ckpt-1234.meta。可以通过传递model.ckpt-1234文件名来恢复模型,如下所示saver.restore(sess,'model.ckpt-1234')

但是,我注意到在r12中,现在有三个输出文件model.ckpt-1234.data-00000-of-000001model.ckpt-1234.indexmodel.ckpt-1234.meta

我看到还原文档中说应该给出类似/train/path/model.ckpt的路径来恢复,而不是文件名。有没有办法一次加载一个检查点文件进行评估?我已经尝试了传递model.ckpt-1234.data-00000-of-000001model.ckpt-1234.indexmodel.ckpt-1234.meta等文件,但会出现以下错误:

W tensorflow/core/util/tensor_slice_reader.cc:95] 无法打开 logdir/2016-12-08-13-54/model.ckpt-0.data-00000-of-00001:数据丢失:不是sstable(错误的魔数):也许您的文件处于不同的文件格式中,需要使用不同的恢复运算符?

NotFoundError(请参见上面的回溯):检查点文件中未找到张量名称“hid_b”logdir/2016-12-08-13-54/model.ckpt-0.index [[Node:save / RestoreV2_1 = RestoreV2 [dtypes =[DT_FLOAT],_device =“/ job:localhost / replica:0 / task:0 / cpu:0”](_recv_save / Const_0,save / RestoreV2_1 / tensor_names,save / RestoreV2_1 / shape_and_slices)]]

W tensorflow/core/util/tensor_slice_reader.cc:95] 无法打开 logdir/2016-12-08-13-54/model.ckpt-0.meta:数据丢失:不是sstable(错误的魔数):也许您的文件处于不同的文件格式中,需要使用不同的恢复运算符?

我正在运行安装在 pip 上的 TensorFlow r12 的 OSX Sierra。

任何帮助指导将非常有用。谢谢。

6个回答

8

我也使用过Tensorflow r0.12,我认为保存和恢复模型没有任何问题。以下是一个简单的代码,您可以尝试一下:

import tensorflow as tf

# Create some variables.
v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.

  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model

虽然在r0.12版本中,检查点被存储在多个文件中,但您可以通过使用通用前缀来恢复它,该通用前缀在您的情况下为“model.ckpt”。


1
我必须添加tf.train.import_meta_graphwith tf.Session() as sess: saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta') saver.restore(sess, "/tmp/model.ckpt") - Alberto Perez
1
我认为你在最后两行的评论应该放在顶部并加粗。这是一眼看到你的答案是否匹配其他人问题的最重要的部分。 - richar8086

5
R12已更改检查点格式。您应该以旧格式保存模型。
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
...
saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V1)
saver.save(sess, './model.ckpt', global_step = step)

根据 TensorFlow v0.12.0 RC0 的发行说明

tf.train.Saver 中的新检查点格式已成为默认设置。旧的 V1 检查点仍可读取;由 write_version 参数进行控制,tf.train.Saver 现在默认以新的 V2 格式写出。它显著降低了恢复期间所需的峰值内存和延迟。

详见我的博客

1
“data-0000-of-0001” 这一部分具体是什么意思? - tnq177

4
您可以按照以下方式恢复模型:
saver = tf.train.import_meta_graph('./src/models/20170512-110547/model-20170512-110547.meta')
            saver.restore(sess,'./src/models/20170512-110547/model-20170512-110547.ckpt-250000'))

'/src/models/20170512-110547/' 路径下有三个文件:

model-20170512-110547.meta
model-20170512-110547.ckpt-250000.index
model-20170512-110547.ckpt-250000.data-00000-of-00001

如果一个目录中有多个检查点,例如:路径./20170807-231648/中有检查点文件:

checkpoint     
model-20170807-231648-0.data-00000-of-00001   
model-20170807-231648-0.index    
model-20170807-231648-0.meta   
model-20170807-231648-100000.data-00000-of-00001   
model-20170807-231648-100000.index   
model-20170807-231648-100000.meta

您可以看到有两个检查点,因此您可以使用以下方法:

saver =    tf.train.import_meta_graph('/home/tools/Tools/raoqiang/facenet/models/facenet/20170807-231648/model-20170807-231648-0.meta')

saver.restore(sess,tf.train.latest_checkpoint('/home/tools/Tools/raoqiang/facenet/models/facenet/20170807-231648/'))

1
抱歉,优化器有问题,我发现如果只有一个检查点,你无法直接从目录中恢复模型,但你可以尝试以下方法:saver = tf.train.import_meta_graph('/home/tools/Tools/raoqiang/facenet/src/models/20170512-110547/model-20170512-110547.meta') saver.restore(sess,'/home/tools/Tools/raoqiang/facenet/src/models/20170512-110547/model-20170512-110547.ckpt-250000') - raoqiang

1

好的,我可以回答自己的问题。我发现我的Python脚本在我的路径中添加了一个额外的“/”,因此我执行了以下操作:

saver.restore(sess,'/path/to/train//model.ckpt-1234')

不知何故,这会导致TensorFlow出现问题。

当我将其删除后,调用:

saver.restore(sess,'/path/to/trian/model.ckpt-1234')

它按预期工作。


1

仅使用model.ckpt-1234

至少对我来说有效


0

我是TF的新手,遇到了同样的问题。在阅读了马元的评论后,我将'.index'文件复制到了与'.data-00000-of-00001'文件相同的'train\ckpt'文件夹中。然后它就可以工作了! 因此,在恢复模型时,.index文件就足够了。 我在Win7上使用了r12的TF。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接