Model 的 call() 和 train_step() 方法会在何时被调用?

9

我正在学习如何自定义训练循环的教程。

https://colab.research.google.com/github/tensorflow/docs/blob/snapshot-keras/site/en/guide/keras/customizing_what_happens_in_fit.ipynb#scrollTo=46832f2077ac

最后一个例子展示了使用自定义训练实现的GAN,其中只定义了__init__train_stepcompile方法。

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        ...

如果我的模型也有一个名为call()的自定义函数,那会发生什么?train_step()是否会覆盖call()fit()调用call()train_step()吗?二者之间的区别是什么?
下面是另一段我写的代码,我想知道fit()调用的是call()还是train_step()
class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    # unpack the data
    inputs, labels = inputs
  
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True) # forward pass
      # Compute the loss value
      # (the loss function is configured in `compile()`)
      loss=self.compiled_loss(labels, predictions, regularization_losses=self.losses)

    # compute the gradients
    grads=tape.gradient(loss, model.trainable_variables)
    # Update weights
    self.optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(labels, predictions)

    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
1个回答

15

这些是不同的概念,使用方式如下:

  • train_stepfit调用。基本上,fit在数据集上循环,并将每个批次提供给train_step(当然还处理度量、记录等)。
  • call用于在调用模型时。确切地说,编写model(inputs)或在您的情况下self(inputs)将使用函数__call__,但是Model类已经定义了该函数,以便其将反过来使用call

这些是技术方面的内容,直观方面:

  • call应该定义模型的前向传递。即输入如何转换为输出。
  • train_step定义了训练步骤的逻辑,通常使用梯度下降算法。通常它会使用call,因为训练步骤往往包括模型的前向传递来计算梯度。
关于你提供的 GAN 教程,我认为它实际上是不完整的。它没有定义 call 也可以正常工作,因为自定义的 train_step 显式地调用了生成器/判别器字段(由于这些是预定义模型,所以可以像通常一样调用它们)。如果你尝试像 gan(inputs) 这样调用 GAN 模型,我会认为你会收到错误消息(我没有测试过)。因此,你总是需要调用 gan.generator(inputs) 来生成内容,例如。
最后(这部分可能有点令人困惑),请注意,你可以子类化一个 Model 来定义自定义的训练步骤,但是通过函数式 API(如 model = Model(inputs, outputs))对其进行初始化,这样你就可以在训练步骤中使用 call 而无需自己定义,因为函数式 API 会处理这一点。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接