Keras回调函数EarlyStopping对比训练损失和验证损失

6
我将为您翻译以下内容:

我正在使用PythonKeras来拟合神经网络。

为了避免过拟合,我希望监控训练/验证损失并创建适当的回调函数,当训练损失远小于验证损失时停止计算。

回调函数的示例:

callback = [EarlyStopping(monitor='val_loss', value=45, verbose=0, mode='auto')]

有没有办法在训练损失相对于验证损失过小时停止训练?

提前感谢您。

1个回答

5
您可以创建一个自定义的回调类来满足您的需求。
我已经创建了一个应该符合您需要的类:
class CustomEarlyStopping(Callback):
    def __init__(self, ratio=0.0,
                 patience=0, verbose=0):
        super(EarlyStopping, self).__init__()

        self.ratio = ratio
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.stopped_epoch = 0
        self.monitor_op = np.greater

    def on_train_begin(self, logs=None):
        self.wait = 0  # Allow instances to be re-used

    def on_epoch_end(self, epoch, logs=None):
        current_val = logs.get('val_loss')
        current_train = logs.get('loss')
        if current_val is None:
            warnings.warn('Early stopping requires %s available!' %
                          (self.monitor), RuntimeWarning)

        # If ratio current_loss / current_val_loss > self.ratio
        if self.monitor_op(np.divide(current_train,current_val),self.ratio):
            self.wait = 0
        else:
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
            self.wait += 1

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch))

我认为你想要在train_loss和validation_loss之间的比例达到一定阈值时停止。这个比值应该介于0.0和1.0之间,但是使用1.0会有风险,因为训练初期验证损失和训练损失可能会剧烈波动。你可以设置一个耐心参数,等待一定数量的周期来确定是否达到了您设定的阈值。例如,使用方式如下:
callbacks = [CustomEarlyStopping(ratio=0.5, patience=2, verbose=1), 
            ... Other callbacks ...]
...
model.fit(..., callbacks=callbacks)

在这种情况下,如果训练损失低于0.5*val_loss连续2个 epoch,程序将会停止运行。

这是否对您有帮助?


非常欢迎 :) 我刚刚拿了EarlyStopping类的源代码并进行了适应... 随意编辑以满足您的需求,那里没有什么魔法! - Nassim Ben
谢谢您提供的示例,但是在使用tensorflow.keras和py3.5时,我遇到了错误“TypeError: super(type, obj): obj must be an instance or subtype of type”。 - Austin
我认为这是一个打字错误,应该是super(CustomEarlyStopping, self).__init__() - Austin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接