How do I tell Keras to stop training based on loss value?

91

Currently I use the following code:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

It tells Keras to stop training when the loss hasn't improved for 2 epochs. But what I want is to stop training once the loss becomes smaller than some constant "THR":

if val_loss < THR:
    break

I saw in the documentation that I can create my own callback: http://keras.io/callbacks/, but I found nothing about how to stop the training process. I need advice.

7 Answers

95

I found the answer. I looked into the Keras source code and found the code for EarlyStopping. Based on it, I made my own callback:

import warnings
from keras.callbacks import Callback

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(EarlyStoppingByLossVal, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)
        elif current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

And use it like this:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
In case it's useful to someone: in this situation I used monitor='loss' and it worked well. - QtRoS
15
Looks like Keras has been updated. The EarlyStopping callback now has min_delta built in. No need to hack the source code anymore, yay! https://dev59.com/01oU5IYBdhLWcg3wnXxO#41459368 - jkdev
4
After re-reading the question and the answers, I need to correct myself: min_delta means "stop early if there isn't enough improvement per epoch (or per several epochs)." The OP's question, however, was how to "stop early once the loss drops below a certain level." - jkdev
2
Eliyah, try this: from keras.callbacks import Callback - ZFTurbo
It should be elif current < self.value: - Cathy

26

The keras.callbacks.EarlyStopping callback does have a min_delta argument. From the Keras documentation:

min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
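
For example, a minimal sketch of how this is wired up (the threshold and patience values here are illustrative, and model, X_train, etc. are assumed to be defined as in the question):

from keras.callbacks import EarlyStopping

# Stop when val_loss improves by less than 0.001 for 3 epochs in a row.
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3, verbose=1)
model.fit(X_train, Y_train, validation_data=(X_valid, Y_valid), callbacks=[early_stop])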


3
For reference, here is the documentation from an earlier version of Keras (1.1.0), which didn't include the min_delta parameter yet: https://faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping - jkdev
How do I make it persist across multiple epochs until 'min_delta' is met? - zyxue
EarlyStopping has another parameter called 'patience': the number of epochs with no improvement after which training stops. - devin
2
While min_delta can be useful, it doesn't fully solve the problem of early stopping by an absolute value. Instead, min_delta works as a difference between values. - NeStack

14

One solution is to call model.fit(nb_epoch=1, ...) inside a for loop. Then you can use a break statement in the loop and do any other custom control flow you like; a sketch of the idea follows.
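
A minimal sketch of that idea, reusing the names from the question (THR is whatever threshold you pick):

for epoch in range(nb_epoch):
    # Train for exactly one epoch, then inspect the latest validation loss.
    history = model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size,
                        nb_epoch=1, shuffle=True, verbose=1,
                        validation_data=(X_valid, Y_valid))
    if history.history['val_loss'][-1] < THR:
        break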


It would be nice if they made a callback that takes in a single function that can do that. - Honesty

11

I solved the same problem using a custom callback.

In the following custom-callback code, assign THR the value at which you want to stop training, and add the callback to your model.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

    def on_batch_end(self, batch, logs={}):
        THR = 0.03  # Assign THR the value at which you want to stop training.
        if logs.get('loss') <= THR:
            self.model.stop_training = True
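
Hooking it up is a single line; a sketch, assuming the model and data from the question:

model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          callbacks=[stopAtLossValue()])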

2

While taking the TensorFlow in Practice specialization, I learned a very elegant trick. It's slightly different from the accepted answer.

Let's take everyone's favorite MNIST data as an example.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if logs.get('accuracy') > 0.90:  # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

So here I set metrics=['accuracy'], and accordingly the condition in the callback class is set on 'accuracy' > 0.90.

You can pick any metric and monitor training this way. Most importantly, you can set different conditions for different metrics and use them simultaneously; see the sketch below.
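
For instance, a hypothetical variant that stops only when an accuracy target and a loss ceiling are both met (the thresholds are made up for illustration):

class accuracy_and_loss_callback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        # Stop once accuracy is high enough AND loss is low enough.
        if logs.get('accuracy') > 0.90 and logs.get('loss') < 0.30:
            print("\n !!! Both targets reached, no further training !!!")
            self.model.stop_training = True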

Hope this helps!


3
The function name should be on_epoch_end. - xarion

0
For me, the model would only stop training if I added a return statement after setting stop_training to True, because I was calling it after self.model.evaluate. So either make sure stop_training = True is the last line of the function, or add a return statement.
def on_epoch_end(self, epoch, logs=None):
    self.epoch += 1
    self.stoppingCounter += 1
    print('\nstopping counter \n', self.stoppingCounter)

    # Stop training if there hasn't been any improvement in 'patience' epochs
    if self.stoppingCounter >= self.patience:
        self.model.stop_training = True
        return

    # Evaluate on an additional validation set if there is one
    if self.testingOnAdditionalSet:
        evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
        self.validationLoss2.append(evaluation[0])
        self.validationAcc2.append(evaluation[1])

-1

If you're using a custom training loop, you can use a collections.deque, a "rolling" list: elements are appended on the right, and once the list grows longer than maxlen, items on the left are popped off. Here are the key lines:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
        break

Here is a complete example:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)

Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
