每10个epoch保存一次模型 tensorflow.keras v2

28
我正在使用在tensorflow v2中作为子模块定义的keras。我使用fit_generator()方法训练我的模型,我希望能够每10个epochs保存一次模型。如何实现这一点?
在Keras中(而不是tf的子模块),我可以使用ModelCheckpoint(model_savepath,period=10)。但是在tf v2中,它们将其更改为ModelCheckpoint(model_savepath, save_freq),其中save_freq可以是'epoch',这种情况下,模型将在每个epoch之后保存。如果save_freq是整数,则在处理完这么多样本后保存模型。但我希望在10个epochs之后保存模型。如何实现这一点?
4个回答

30
使用tf.keras.callbacks.ModelCheckpoint,使用save_freq='epoch'并传递一个额外参数period=10
尽管官方文档没有记录这一点,但这是实现的方式(注意可以传递period,只是没有解释其作用)。

3
我收到以下警告:WARNING:tensorflow:'period' argument is deprecated. Please use 'save_freq' to specify the frequency in number of samples seen. 那么,我猜测这个功能很快就会被淘汰。在这种情况下,我该如何实现它呢? - Nagabhushan S N
2
我相信唯一的替代方案是计算每个时期的示例数量,并将该整数传递给“save_freq”乘以您想要的保存间隔之间的时期数。 - bluesummers
1
@bluesummers “每个时代的示例”这应该是我的批量大小,对吗? - Tom
每个时期的示例是您想要在检查点之间通过网络传递的样本数量 - 这意味着如果您有100个样本(样本!=批处理,批处理是一批样本),并且您放置了400,则会保存每4个时期。 - bluesummers
2
我有和@NagabhushanSN一样的问题。我计算了每个epoch的样本数,以计算我想要保存模型的样本数,但似乎不起作用。批量大小为64,在测试案例中,我每个epoch使用10个步骤。如果我想要每3个epochs保存一次模型,则样本数为64103=1920。我将其用于sav_freq,但输出显示模型在第1个epoch、第2个epoch、第9个epoch、第11个epoch、第14个epoch保存,并且仍在运行。无法理解。period选项似乎可以正常工作,但会弃用。 - beeprogrammer
显示剩余4条评论

5

明确计算每个epoch中批次数量的方法对我效果不错。

BATCH_SIZE = 20
STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
SAVE_PERIOD = 10

# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          batch_size=BATCH_SIZE,
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=50, 
          callbacks=[cp_callback],
          validation_data=(test_images,test_labels),
          verbose=0)

4

在被接受的答案中提到的参数period现在已不再可用。

使用save_freq参数是一种替代方案,但是有风险,正如文档中所述; 例如,如果数据集大小发生变化,可能会变得不稳定:请注意,如果保存不与时期对齐,则监视指标可能潜在地不太可靠(同样来自文档)。

因此,我使用子类作为解决方案:

class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
        super(EpochModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only,
                                                   mode, "epoch", options)
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        # pylint: disable=protected-access
        if self.epochs_since_last_save % self.frequency == 0:
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        pass

将其用作

callbacks=[
     EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]

请注意,根据您的TF版本,您可能需要更改对超类__init__的调用中的args参数。

0

我也来到这里寻找答案,并想指出一些与之前回答不同的变化。我目前正在使用TF版本2.5.0,period=是有效的,但仅当回调中没有save_freq=时才有效。

my_callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=path
        period=N
    )
]

这对我来说没有问题,即使在回调文档中没有记录句点


谢谢更新。它还没有被移除吗?它已经被标记为弃用了,我想它现在应该已经被移除了。它仍然是弃用的吗? - Nagabhushan S N
3
截至TF Ver 2.5.0版本,该功能仍然存在并正常工作。它仍被标记为过时的,警告信息为:'period'参数已经过时。请使用'save_freq'来指定保存频率,单位是批次数。但它仍按照预期进行保存。 - Andrew Crouch

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接