如何使用tf.train.MonitoredTrainingSession仅恢复特定的变量

5
如何指定tf.train.MonitoredTrainingSession仅恢复变量子集并对其余变量执行初始化?
从cifar10教程开始.. https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py ..我创建了要恢复和初始化的变量列表,并使用Scaffold指定它们,然后将其传递给MonitoredTrainingSession。
  restoration_saver = Saver(var_list=restore_vars)
  restoration_scaffold = Scaffold(init_op=variables_initializer(init_vars),
                                  ready_op=constant([]),
                                  saver=restoration_saver)

但是这会导致以下错误:
RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]
在上述错误信息中列出的未初始化变量是我的“init_vars”列表中的变量。
该异常由SessionManager.prepare_session()引发。该方法的源代码似乎表明,如果从检查点恢复了会话,则不会运行init_op。因此,看起来您可以具有恢复变量或已初始化变量,但不能同时具有两者。
5个回答

4

好的,正如我所预料的那样,我通过实现一个新的RefinementSessionManager类,基于现有的tf.training.SessionManager类,获得了我想要的结果。这两个类几乎完全相同,只是我修改了prepare_session方法,使其无论模型是否是从检查点加载都调用init_op。

这样我就可以从检查点中加载变量列表,并在init_op中初始化剩余的变量。

我的prepare_session方法如下:

  def prepare_session(self, master, init_op=None, saver=None,
                  checkpoint_dir=None, wait_for_checkpoint=False,
                  max_wait_secs=7200, config=None, init_feed_dict=None,
                  init_fn=None):

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
    master,
    saver,
    checkpoint_dir=checkpoint_dir,
    wait_for_checkpoint=wait_for_checkpoint,
    max_wait_secs=max_wait_secs,
    config=config)

    # [removed] if not is_loaded_from_checkpoint:
    # we still want to run any supplied initialization on models that
    # were loaded from checkpoint.

    if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
      raise RuntimeError("Model is not initialized and no init_op or "
                     "init_fn or local_init_op was given")
    if init_op is not None:
      sess.run(init_op, feed_dict=init_feed_dict)
    if init_fn:
      init_fn(sess)

    # [...]

希望这能帮助其他人。

作为对未来旅行者的警告,如果您不提供一个init操作,它似乎会重新初始化您恢复的变量,使您的模型无用。您可以通过提供no-op来解决这个问题。 - Cory Nezin

2

@avital的提示是有效的,为了更加完整:在MonitoredTrainingSession中传递一个脚手架对象,其中包括local_init_opready_for_local_init_op。像这样:

model_ready_for_local_init_op = tf.report_uninitialized_variables(
            var_list=var_list)
model_init_tmp_vars = tf.variables_initializer(var_list)
scaffold = tf.train.Scaffold(saver=model_saver,
               local_init_op = model_init_tmp_vars,
               ready_for_local_init_op = model_ready_for_local_init_op)
with tf.train.MonitoredTrainingSession(...,
                scaffold=scaffold,
                ...) as mon_sess:
   ...

我不明白。如果您将模型变量指定为local_init_op,那么您不是在初始化模型变量两次吗?难道不应该是那些未保存在模型中的其他变量吗? - Magnus
首先,我认为初始化变量两次并不是一个大问题;其次,您可以选择在var_list中包含哪些变量。这个问题是关于恢复一些变量并初始化其他可能已经存在于您的模型中的变量。但也许我没有理解您的问题? - Bastiaan
如果我恢复一个变量然后再初始化它会发生什么?那么恢复的状态不就被覆盖了吗?因为var_list只包括我正在保存的变量,对吧?我的理解是,当var_list中的所有变量都被初始化(从模型中恢复)时,model_ready_for_local_init_op会让我知道这一点(即返回一个空列表),然后我可以继续初始化其他未存储的变量(例如model_init_tmp_vars)。 - Magnus
从 Scaffold 的文档中:ready_for_local_init_op:可选操作,用于验证全局变量是否已初始化并且可以运行local_init_op。当全局变量已初始化时必须返回一个空的1D字符串张量,或者返回列出未初始化全局变量名称的非空1D字符串张量。那么如何将同一var_list用于ready_for_local_init_oplocal_init_op是有意义的呢? - Magnus

1

Scaffold参数包括以下内容:

  • init_op
  • ready_op
  • local_init_op
  • ready_for_local_init_op

init_op仅在从检查点恢复时调用。

if not is_loaded_from_checkpoint:
  if init_op is None and not init_fn and self._local_init_op is None:
    raise RuntimeError("Model is not initialized and no init_op or "
                   "init_fn or local_init_op was given")
  if init_op is not None:
    sess.run(init_op, feed_dict=init_feed_dict)
  if init_fn:
    init_fn(sess)

因此,实际上init_op在这里无法帮助。如果您可以编写新的SessionManager,则可以遵循@user550701的建议。我们也可以使用local_init_op,但在分布式环境中可能有些棘手。

Scaffold将为我们生成默认的init_oplocal_init_op详细信息

  • init_op:将初始化tf.global_variables
  • local_init_op:将初始化tf.local_variables

我们应该初始化变量并同时不破坏默认机制。

单个工作器情况

您可以像这样创建local_init_op

target_collection = [] # Put your target tensors here
collection = tf.local_variables() + target_collection
local_init_op = tf.variables_initializer(collection)
ready_for_local_init_op = tf.report_uninitialized_variables(collection)

分布式情况

我们应该注意避免重复初始化target_collection,因为local_init_op将在多个工作节点上被多次调用。如果变量是局部的,这没有任何区别。但如果它们是全局变量,我们应该确保只初始化一次。为了解决这个问题,我们可以操作collection变量。在主节点上,它包括本地变量和我们的target_collection。而对于非主节点,我们只将本地变量放入其中。

if is_chief:
   collection = tf.local_variables() + target_collection
else:
   collection = tf.local_variables()

总的来说,这有点棘手,但我们不需要入侵tensorflow。

1
你可以使用local_init_op参数解决此问题,该参数在从检查点加载后运行。

0

我曾经遇到过同样的问题,我的解决方案是

checkpoint_restore_dir_for_monitered_session = None
scaffold = None
if params.restore:
    checkpoint_restore_dir_for_monitered_session = checkpoint_save_dir

    restore_exclude_name_list = params.restore_exclude_name_list
    if len(restore_exclude_name_list) != 0:
        variables_to_restore, variables_dont_restore = get_restore_var_list(restore_exclude_name_list)
        saver_for_restore = tf.train.Saver(var_list=variables_to_restore, name='saver_for_restore')
        ready_for_local_init_op = tf.report_uninitialized_variables(variables_to_restore.values())

        local_init_op = tf.group([
            tf.initializers.local_variables(),
            tf.initializers.variables(variables_dont_restore)
            ])

        scaffold = tf.train.Scaffold(saver=saver_for_restore,
                ready_for_local_init_op=ready_for_local_init_op,
                local_init_op=local_init_op)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_restore_dir_for_monitered_session, 
        save_checkpoint_secs=None,  # don't save ckpt
        hooks=train_hooks,
        config=config,
        scaffold=scaffold,
        summary_dir=params.log_dir) as sess:
    pass

在这段代码片段中,get_restore_var_list 获取 variables_to_restorevariables_dont_restore
saver_for_restore 只会恢复 variables_to_restore 中被 ready_for_local_init_op 检查并通过的变量。
然后运行 local_init_op,它将初始化 local_variables()variables_dont_restore(可能是 tf.variance_scaling_initializer...)。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接