Applying callbacks in a custom training loop in Tensorflow 2.0

25

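I'm writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I want to add callbacks to the training loop. In Keras, I know we pass them as an argument to the 'fit' method, but I can't find resources on how to use these callbacks in a custom training loop. Here is the custom training loop code from the Tensorflow documentation: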

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
6 Answers

14

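I've had this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; (3) I don't want to re-implement them all myself. Tensorflow's design philosophy is to let developers opt in gradually to its lower-level APIs. As @HyeonPhilYoun notes in his comment below, the official documentation for tf.keras.callbacks.Callback provides the example we're looking for.

The following has worked for me, but it could be improved by reverse engineering tf.keras.Model.

The trick is to use tf.keras.callbacks.CallbackList and then manually trigger its lifecycle events from within your custom training loop. This example uses tqdm to give attractive progress bars, but CallbackList has an add_progbar initialization argument that lets you use the default progress bar instead. training_model is a typical instance of tf.keras.Model.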

from tqdm.notebook import tqdm, trange

# Populate with typical keras callbacks
_callbacks = []

callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)

logs = {}
callbacks.on_train_begin(logs=logs)

# Presentation
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
    postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}")
epochs.set_postfix(loss=0, accuracy=0)

# Get a stable test set so epoch results are comparable
test_batches = batches(test_x, test_Y)

for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)

    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play
    training_batches = batches(x, Y)

    # Presentation
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}",
        position=1,
        leave=False)

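    # Distinct loop names avoid shadowing the full dataset x used above
    for (batch, (x_batch, y_batch)) in enumerated_batches:
        training_model.reset_states()

        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_train_batch_begin(batch, logs=logs)

        # Train on this batch's labels, not the full label set Y
        logs = training_model.train_on_batch(
            x=x_batch, y=y_batch, return_dict=True)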

        callbacks.on_train_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

        # Presentation
        enumerated_batches.set_postfix(
            loss=float(logs["loss"]),
            accuracy=float(logs["accuracy"]))

    for (batch, (x_batch, y_batch)) in enumerate(test_batches):
        training_model.reset_states()

        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_test_batch_begin(batch, logs=logs)

        # Evaluate on this batch's labels, not the full label set Y
        logs = training_model.test_on_batch(
            x=x_batch, y=y_batch, return_dict=True)

        callbacks.on_test_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

    # Presentation
    epochs.set_postfix(
        loss=float(logs["loss"]),
        accuracy=float(logs["accuracy"]))

    callbacks.on_epoch_end(epoch, logs=logs)

    # NOTE: This is a decent place to check on your early stopping
    # callback. Callbacks such as EarlyStopping signal a stop by
    # setting training_model.stop_training to True.
    if getattr(training_model, "stop_training", False):
        break


callbacks.on_train_end(logs=logs)

# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
    if isinstance(cb, tf.keras.callbacks.History):
        history_object = cb
assert history_object is not None
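
To connect this pattern back to a hand-written gradient step like the one in the question, here is a minimal sketch. The toy regression model, the random data, and the EventRecorder callback are all assumptions made up for illustration; only the CallbackList hooks mirror the approach above. The step runs eagerly here, though it could probably be wrapped in tf.function as in the question.

```python
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records which hooks were triggered."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end_{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Toy model and optimizer standing in for the generator/discriminator
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)


def train_step(x, y):
    # Hand-written step with GradientTape; returns the batch loss
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Made-up data: learn y = x1 + x2
x = tf.random.normal([32, 2])
y = tf.reduce_sum(x, axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

recorder = EventRecorder()
history = tf.keras.callbacks.History()
callbacks = tf.keras.callbacks.CallbackList([recorder, history], model=model)

callbacks.on_train_begin()
for epoch in range(3):
    callbacks.on_epoch_begin(epoch)
    for step, (xb, yb) in enumerate(dataset):
        callbacks.on_train_batch_begin(step)
        loss = train_step(xb, yb)
        callbacks.on_train_batch_end(step, logs={"loss": float(loss)})
    # Report the last batch loss as the epoch's loss
    callbacks.on_epoch_end(epoch, logs={"loss": float(loss)})
callbacks.on_train_end()
```

Passing a History instance explicitly keeps a direct reference to it, so there is no need to search the list afterwards; history.history["loss"] then holds one entry per epoch.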

3
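Thank you for such a detailed answer! It helped me a lot. It's frustrating how little official documentation there is on callbacks in custom loops! - emil
1
An official document also specifically points to this approach as the most appropriate one. You can check it in the examples section. https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback - HyeonPhil Youn
1
This should be accepted as the official answer... Well done! - Rayyan
1
Great. I tried this with the Tensorboard callback and it works really well. In my case it looks like this: ''' tensorboard_callback = keras.callbacks.TensorBoard( log_dir='./callbacks/tensorboard', histogram_freq=1) _callbacks = [tensorboard_callback] callbacks = keras.callbacks.CallbackList( _callbacks, add_history=True, model=encoder) logs_ae = {} callbacks.on_train_begin(logs=logs_ae)... ... ''' - UlrikP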

5

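The simplest approach is to check whether the loss has changed over the expected period and, if it has not, to break out of or otherwise steer the training process.

Here is one way you can implement a custom early-stopping callback: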

import numpy as np

def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
    # No early stopping until at least 2*patience epochs are recorded
    if len(LossList)//patience < 2 :
        return False
    # Mean loss for last patience epochs and second-last patience epochs
    mean_previous = np.mean(LossList[::-1][patience:2*patience]) #second-last
    mean_recent = np.mean(LossList[::-1][:patience]) #last
    # you can use relative or absolute change
    delta_abs = np.abs(mean_recent - mean_previous) #abs change
    delta_rel = np.abs(delta_abs / mean_previous)  # relative change
    if delta_rel < min_delta :
        print("*CB_ES* Loss didn't change much from last %d epochs"%(patience))
        print("*CB_ES* Percent change in loss value:", delta_rel*1e2)
        return True
    else:
        return False

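This Callback_EarlyStopping checks your metric or loss at every epoch. It compares the mean loss over the most recent patience epochs with the mean over the patience epochs before that, and returns True once the relative change falls below min_delta. You can use that True signal to break out of the training loop. To answer your question fully, you could use it in your sample training loop like this: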

gen_loss_seq = []
for epoch in range(epochs):
  #in your example, make sure your train_step returns gen_loss
  batch_losses = [float(train_step(image_batch)) for image_batch in dataset]
  #ideally, you can have a validation_step and get gen_valid_loss
  gen_loss_seq.append(np.mean(batch_losses))
  #check each epoch; stop if the mean loss of the last 20 epochs differs
  #by less than 10% from the mean of the 20 epochs before them
  stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
  if stopEarly:
    print("Callback_EarlyStopping signal received at epoch= %d/%d"%(epoch,epochs))
    print("Terminating training ")
    break
       

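Of course, you can add complexity in many ways: which losses or metrics to track, whether you care about the loss at a particular epoch or its moving average, whether you care about relative or absolute change, and so on. You can refer to the Tensorflow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is commonly used with the popular tf.keras.Model.fit method.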
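
As one illustration of those choices, here is a rough sketch of a stopper that tracks the best value seen so far and counts epochs without improvement. The EarlyStopper class and its parameter names are hypothetical; it is loosely modeled on the idea behind tf.keras.callbacks.EarlyStopping and is not that class's actual implementation.

```python
class EarlyStopper:
    """Hypothetical helper: signal a stop once the monitored value has not
    improved by more than min_delta for `patience` consecutive epochs."""

    def __init__(self, patience=3, min_delta=0.0, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = abs(min_delta)
        self.mode = mode
        self.best = None
        self.wait = 0  # epochs since the last improvement

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def update(self, value):
        """Record one epoch's value; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


# Made-up loss values for illustration
stopper = EarlyStopper(patience=2, min_delta=0.01)
losses = [1.0, 0.8, 0.795, 0.796, 0.7]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.update(loss):
        stopped_at = epoch
        break
```

In a custom loop, update would be called once per epoch with the validation loss, and a True return value would trigger the break.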


5
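Unfortunately, this answer applies only to the very specific case of wanting the EarlyStopping callback. There are many other useful callbacks that could be reused without implementing them from scratch. - Stanley F.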

2
aapa3e8's answer is correct, but below I provide a Callback_EarlyStopping implementation that is more similar to tf.keras.callbacks.EarlyStopping.
def Callback_EarlyStopping(MetricList, min_delta=0.1, patience=20, mode='min'):
    #No early stopping for the first patience epochs 
    if len(MetricList) <= patience:
        return False
    
    min_delta = abs(min_delta)
    if mode == 'min':
      min_delta *= -1
    else:
      min_delta *= 1
    
    #last patience epochs 
    last_patience_epochs = [x + min_delta for x in MetricList[::-1][1:patience + 1]]
    current_metric = MetricList[::-1][0]
    
    if mode == 'min':
        if current_metric >= max(last_patience_epochs):
            print(f'Metric did not decrease for the last {patience} epochs.')
            return True
        else:
            return False
    else:
        if current_metric <= min(last_patience_epochs):
            print(f'Metric did not increase for the last {patience} epochs.')
            return True
        else:
            return False

3
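Unfortunately, this question is not about early stopping but about callbacks in general. Why does everyone here assume the asker only wants this one particular callback? - Stanley F.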

2

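I think you need to implement the callback functionality manually. It isn't too difficult. For example, you can have the 'train_step' function return the losses and implement callbacks such as early stopping in your 'train' function. For callbacks such as learning rate scheduling, the function tf.keras.backend.set_value(generator_optimizer.lr, new_lr) comes in handy. The callback functionality would thus live in your 'train' function.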
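
A minimal sketch of that idea for learning rate scheduling follows. The step_decay schedule, its default values, and the train signature are assumptions made up for illustration; set_lr stands in for whatever applies the new rate, for example a function that calls tf.keras.backend.set_value(generator_optimizer.lr, new_lr).

```python
def step_decay(epoch, initial_lr=0.1, drop=0.5, epochs_per_drop=10):
    """Hypothetical schedule: multiply the rate by `drop` every
    `epochs_per_drop` epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


def train(epochs, train_one_epoch, set_lr):
    """Run the loop, applying the schedule before each epoch.

    train_one_epoch: callable that runs all batches and returns a loss.
    set_lr: callable that applies a new learning rate to the optimizer.
    """
    losses = []
    for epoch in range(epochs):
        set_lr(step_decay(epoch))
        losses.append(train_one_epoch())
    return losses


# Stand-ins so the sketch runs without a model: record the rates applied
applied = []
losses = train(epochs=25, train_one_epoch=lambda: 0.0, set_lr=applied.append)
```

Keeping the schedule as a plain function makes it easy to check on its own before wiring it to a real optimizer.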


2

A custom training loop is just a normal Python loop, so you can use an if statement to break out of it when some condition is met. For example:

if len(loss_history) > patience:
    if loss_history.popleft()*(1 - delta) < min(loss_history):
        print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
              f'validation loss in the last {patience} epochs.')
        break

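If the loss has not improved by at least the fraction delta within the last patience epochs, the loop is broken. Here I use a collections.deque, which works well as a rolling list that keeps only the information for the most recent epochs.
Here is a full implementation, including the example from the Tensorflow documentation: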
from collections import deque

patience = 3
delta = 0.001

loss_history = deque(maxlen=patience + 1)

for epoch in range(1, 25 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.CategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.CategoricalAccuracy()

    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))

    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         test_loss.result(),
                         train_acc.result(),
                         test_acc.result()))

    loss_history.append(test_loss.result())

    if len(loss_history) > patience:
        if loss_history.popleft()*(1 - delta) < min(loss_history):
            print(f'\nEarly stopping. No improvement of more than {delta:.5%} in '
                  f'validation loss in the last {patience} epochs.')
            break

Epoch  1 Loss: 0.191 TLoss: 0.282 Acc: 68.920% TAcc: 89.200%
Epoch  2 Loss: 0.157 TLoss: 0.297 Acc: 70.880% TAcc: 90.000%
Epoch  3 Loss: 0.133 TLoss: 0.318 Acc: 71.560% TAcc: 90.800%
Epoch  4 Loss: 0.117 TLoss: 0.299 Acc: 71.960% TAcc: 90.800%

Early stopping. No improvement of more than 0.10000% in validation loss in the last 3 epochs.
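
The rolling check can also be pulled out into a small function so it is testable apart from the model. This is only a sketch: should_stop is a hypothetical name, it assumes delta is a relative improvement margin (so the oldest loss is scaled by 1 - delta), and it leaves eviction to the deque's maxlen rather than calling popleft. The loss values are the test losses from the output above.

```python
from collections import deque


def should_stop(loss_history, patience, delta):
    """Return True if none of the last `patience` losses improved on the
    oldest retained loss by more than the relative margin `delta`."""
    if len(loss_history) <= patience:
        return False  # not enough epochs recorded yet
    oldest = loss_history[0]
    recent = list(loss_history)[1:]
    return oldest * (1 - delta) < min(recent)


patience, delta = 3, 0.001
history = deque(maxlen=patience + 1)  # oldest entry is evicted automatically
stopped_epoch = None
for epoch, loss in enumerate([0.282, 0.297, 0.318, 0.299], start=1):
    history.append(loss)
    if should_stop(history, patience, delta):
        stopped_epoch = epoch
        break
```

With these values the check fires at epoch 4, because 0.282 is still lower than every later loss.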

0

I used @Rob Hall's approach and added the tensorboard callback, and it does work. In my case the code looks like this:

tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='./callbacks/tensorboard',
    histogram_freq=1)

_callbacks = [tensorboard_callback]
callbacks = keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=encoder)

logs_ae = {}
callbacks.on_train_begin(logs=logs_ae)
...
...

