如何在Tensorflow中暂停/恢复训练

11

这个问题是在文档中提供保存和恢复的信息之前提出的。目前我认为这个问题已经过时了,建议大家参考Save and Restore官方文档。

旧问题的要点:

我对CIFAR教程使用TF没有问题。我改变了代码,将train_dir(具有检查点和模型的目录)保存到已知位置。

这就带来了我的问题:如何使用TF暂停和恢复训练?

5个回答

14

TensorFlow使用类似于图的计算,节点(Ops)和边缘(变量,也称为状态),并且它为其Vars提供了Saver

因此,由于它是分布式计算,您可以在一个机器/处理器上运行图的一部分,在另一个机器上运行其余部分,同时保存状态(Vars)并在下次继续工作时使用。

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

稍后您可以使用

tf.train.Saver.restore(sess, save_path)

恢复您保存的变量。

Saver用法


使用tf.train.Saver.restore(sess, save_path)命令会导致错误,因为restore方法需要一个Saver实例。 - Kongsea

2

如Hamed所述,使用tensorflow的正确方法是:

Original Answer翻译成"最初的回答"

    saver=tf.train.Saver()
    save_path='checkpoints/'
    -----> while training you can store using
    saver.save(sess=session,save_path=save_path)
    -----> and restore
    saver.restore(sess=session,save_path=save_path)

这将加载您最后保存的模型,并仅从那里开始训练(如果需要)。"Original Answer"翻译成"最初的回答"。

2
使用tf.train.MonitoredTrainingSession()有助于在我的机器重新启动时恢复训练。
需要注意以下几点:
1. 确保保存检查点。在tf.train.saver()中,您可以指定max_checkpoints以保留检查点。 2. 在tf.train.MonitoredTrainingSession(checkpoint='dir_path',save_checkpoint_secs=)中指定检查点目录。 根据save_checkpoint_secs参数,上述会话将不断保存和更新检查点。 3. 当您不断保存检查点时,上述功能会寻找最新的检查点,并从那里恢复训练。
"最初的回答"

1

使用Tensorflow 2,您现在可以在调用模型的fit函数时使用BackupAndRestore回调

model.fit(
  train_dataset,
  validation_data=validation_dataset,
  epochs=25,
  callbacks=[
    tf.keras.callbacks.BackupAndRestore(
      # The path where your backups will be saved. Make sure the
      # directory exists prior to invoking `fit`.
      "./training_backup",
      # How often you wish to save a checkpoint. Providing "epoch"
      # saves every epoch, providing integer n will save every n steps.
      save_freq="epoch",
      # Deletes the last checkpoint when saving a new one.
      delete_checkpoint=True,
    )
  ]
)

如果您的fit函数因任何原因退出,只需再次调用它,回调函数将负责加载最新的检查点并继续您的进度。

0

1.打开检查点文件并从中删除不需要的模型。 2.将model_checkpoint_path设置为您要继续的最后一个最佳模型。 文件内容如下:

model_checkpoint_path: "model_gs_043k"
all_model_checkpoint_paths: "model_gs_041k"
all_model_checkpoint_paths: "model_gs_042k"
all_model_checkpoint_paths: "model_gs_043k"

这里,它继续使用 model_gs_043k

3.删除文件以及事件文件(如果存在),然后您就可以运行培训了。 培训将从模型文件夹中存在的最后一个最佳保存模型开始。如果不存在模型文件,则将从头开始培训。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接