如何在Tensorflow中恢复检查点时获取global_step?

19

我是这样保存我的会话状态的:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
当我恢复时,我希望获取从中恢复的检查点的全局步骤值。这是为了从中设置一些超参数。
通过运行并解析检查点目录中的文件名来完成这项任务是不太正规的。但肯定有更好、内置的方式吧?
8个回答

28

通常的模式是使用一个global_step变量来跟踪步骤。

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

然后您可以通过保存来保存

saver.save(sess, save_path, global_step=global_step)

当您进行恢复时,global_step 的值也会被恢复。


4
每次我恢复训练时,全局步骤变量都会被重置为0,这不起作用。 - Pranay Mathur
这意味着你正在保存的 global_step 到检查点是 0,或者在恢复它后重新初始化为 0。 - Yaroslav Bulatov
这可能是一个不错的解决方案,但如果saver.restore可以返回global_step,那就更简单了。我们只需要执行'global_step=saver.restore(...)'。您认为TensorFlow团队会对这个方向感兴趣吗? - Sung Kim
似乎在某些情况下可能会有用,但也似乎需要大量的工作——现在TF 1.0已经发布,任何对API的更改都必须经过API审查。 - Yaroslav Bulatov
1
@YaroslavBulatov 这对于在此处进行Inception v3训练无效:https://github.com/tensorflow/models/tree/master/inception/inception 在恢复模型后,全局步骤始终为0。 - Visionscaper

7

这是一种比较巧妙的方法,但其他解决方案对我来说都没有起作用。

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

更新 2017年9月

我不确定这是否是由于更新而开始工作的,但以下方法似乎在使global_step更新并正确加载方面非常有效:

创建两个操作。一个用于保存global_step,另一个用于增加它:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')

现在,在您的训练循环中,每次运行训练操作时都要运行增量操作。

sess.run([train_op,increment_global_step],feed_dict=feed_dict)

如果您想在任何时候将全局步数值作为整数检索出来,只需在加载模型后使用以下命令:

sess.run(global_step)

这对于创建文件名或计算当前纪元时间而无需使用第二个tensorflow变量来保存该值非常有用。例如,加载时计算当前纪元将会是以下内容:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)

1
您可以使用global_step变量来跟踪步骤,但是如果在您的代码中,您正在将此值初始化或分配给另一个step变量,则可能不一致。
例如,您可以使用以下方式定义global_step
global_step = tf.Variable(0, name='global_step', trainable=False)

将其分配给您的培训操作:

train_op = optimizer.minimize(loss, global_step=global_step)

保存在您的检查点中:

saver.save(sess, checkpoint_path, global_step=global_step)

并从您的检查点恢复:

saver.restore(sess, checkpoint_path) 

global_step的值也被恢复了,但是如果您将其分配给另一个变量,比如step,那么您必须这样做:

step = global_step.eval(session=sess)

变量step包含检查点中保存的最后一个global_step
最好也从图形中定义global_step,而不是作为零变量(如先前定义的)。
global_step = tf.train.get_or_create_global_step()

如果存在,则获取最后一个全局步骤 global_step,否则创建一个。

这是我在这个问题上看到的最干净的解决方案!+1。 - Sibbs Gambling

1

我遇到了和Lawrence Du一样的问题,无法在恢复模型时找到获取global_step的方法。所以我应用了他的技巧到我正在使用的Tensorflow/models github仓库中Inception V3训练代码。下面的代码还包含与pretrained_model_checkpoint_path相关的修复。

如果您有更好的解决方案或者知道我漏掉了什么,请留言!

无论如何,这段代码对我有效:

...

# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files, use the base name of the model in
    # the checkpoint path. E.g. my-model-path/model.ckpt-291500
    #
    # Because we need to give the base name you can't assert (will always fail)
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK : global step is not restored for some unknown reason
    last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])

    # assign to global step
    sess.run(global_step.assign(last_step))

...

for step in range(last_step + 1, FLAGS.max_steps):

  ...

这种方法不适用于官方预训练的inception v3模型检查点(http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz),因为检查点文件名仅包含inception_v3.ckpt。 - Vipin Pillai
在所有预训练模型中,全局步骤都会被重置为零。因此,这些模型可以用于初始化图形以进行微调,而无需设置全局步骤。 - David Boho

1

简短概述

作为tensorflow变量(将在会话中评估)

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

或者:作为numpy整数(不需要任何会话):

reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')


长答案

至少有两种方法从检查点中检索全局变量。作为tensorflow变量或numpy整数。如果在Saversave方法中没有提供global_step作为参数,则解析文件名将不起作用。对于预训练模型,请参见答案末尾的备注。

作为Tensorflow变量

如果您需要global_step变量来计算一些超参数,您可以使用tf.train.get_or_create_global_step()。这将返回一个tensorflow变量。因为变量在会话中稍后被评估,所以您只能使用tensorflow操作来计算超参数。例如:max(global_step, 100)不起作用。您必须使用tensorflow等效的tf.maximum(global_step, 100),以便稍后在会话中进行评估。

在会话中,您可以使用saver.restore(sess, checkpoint_path)来使用检查点初始化全局步骤变量。

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100) 
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

    # for verification you can print the global step and your hyper parameter
    print(sess.run([global_step, hyper_parameter]))

或:作为numpy整数(无会话)

如果您需要全局步骤变量作为标量而不启动会话,则还可以直接从您的检查点文件中读取此变量。 您只需要一个NewCheckpointReader。 由于旧版tensorflow中存在bug,因此应将检查点文件的路径转换为绝对路径。 使用读取器,您可以将模型的所有张量作为numpy变量获取。 全局步骤变量的名称是常量字符串tf.GraphKeys.GLOBAL_STEP,定义为'global_step'

absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)

预训练模型备注: 在大多数可在线获取的预训练模型中,全局步骤会被重置为零。因此,这些模型可以用于初始化模型参数以进行微调,而不会覆盖全局步骤。


0

目前的0.10rc0版本似乎有所不同,不再有tf.saver()。现在是tf.train.Saver()。此外,save命令会将信息添加到save_path文件名中以获取global_step,因此我们不能只调用相同的save_path上的restore,因为那不是实际的保存文件。

我现在看到的最简单的方法是使用SessionManager和类似这样的saver:

my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))

就是这样。现在你有了一个会话,它要么从my_checkpoint_dir中恢复(在调用之前确保该目录存在),要么如果没有检查点,则创建一个新的会话并调用init_op来初始化变量。

当你想保存时,只需将其保存到该目录中任何你想要的名称,并传递global_step。以下是一个示例,在循环中保存步骤变量作为全局步骤,因此如果你杀死程序并重新启动它以恢复检查点,则返回到该点:

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)

这将在my_checkpoint_dir中创建文件,例如"model.ckpt-1000",其中1000是传递的global_step。如果程序继续运行,则会得到更多类似于"model.ckpt-2000"的文件。上面的SessionManager会在程序重新启动时获取最新的这个文件。checkpoint_path可以是任何你想要的文件名,只要它在checkpoint_dir中即可。save()将使用添加了global_step的方式创建该文件(如上所示)。它还会创建一个"checkpoint"索引文件,这就是SessionManager找到最新保存的检查点的方式。


0
一个变量没有按预期恢复的原因很可能是因为它是在你创建tf.Saver()对象之后才被创建的。
当你没有明确指定var_list或将其指定为None时,创建tf.Saver()对象的位置非常重要。许多程序员预期的行为是在调用save()方法时保存图中的所有变量,但事实并非如此,这应该作为文档说明。在对象创建时保存了图中所有变量的快照。
除非您遇到任何性能问题,否则最安全的做法是在决定保存进展时立即创建saver对象。否则,请确保在创建所有变量后创建saver对象。
此外,传递给 saver.save(sess, save_path, global_step=global_step)global_step 仅仅是用于创建文件名的计数器,与是否将其恢复为 global_step 变量无关。在我看来,这是一个参数误称,因为如果您要在每个时期结束时保存进度,最好将时期号作为此参数传递。

0

只是记录一下我的全局步骤保存和恢复的解决方案。

保存:

global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)

恢复:

if os.path.exists(model_path):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print("Model restore finished, current globle step: %d" % global_step.eval())

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接