我正在学习关于如何自定义训练循环的教程。最后一个例子展示了一个使用自定义训练实现的 GAN，其中只定义了 `__init__`、`train_step` 和 `compile` 三个方法。
class GAN(keras.Model):
    """GAN trained through a custom ``train_step``.

    Only ``__init__``, ``compile`` and ``train_step`` are overridden; this
    class defines no ``call()`` method of its own.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-models and the size of the latent vector.

        Args:
            discriminator: model used as the discriminator.
            generator: model used as the generator.
            latent_dim: dimensionality of the latent space.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """Configure one optimizer per sub-model plus a shared loss function.

        Args:
            d_optimizer: optimizer for the discriminator.
            g_optimizer: optimizer for the generator.
            loss_fn: loss function used by ``train_step``.
            **kwargs: forwarded unchanged to ``keras.Model.compile`` (e.g.
                ``run_eagerly``), so base-class options remain available.
        """
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one training step on a batch of real images."""
        # If the batch arrives as a tuple (e.g. images with labels), keep only
        # its first element.
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        ...  # the rest of the step is elided in this excerpt
如果我的模型还定义了一个名为 `call()` 的自定义方法，会发生什么？`train_step()` 会覆盖 `call()` 吗？`fit()` 会调用 `call()` 和 `train_step()` 吗？二者之间有什么区别？下面是我写的另一段代码，我想知道 `fit()` 调用的是 `call()` 还是 `train_step()`：
class MyModel(tf.keras.Model):
    """Embedding -> GRU -> Dense model with a custom ``train_step``.

    ``fit()`` invokes ``train_step()`` once per batch. Inside it,
    ``self(...)`` performs the forward pass, which Keras dispatches to
    ``call()``. Overriding ``train_step()`` therefore does not replace
    ``call()``: the former defines one optimization step, the latter the
    forward computation it uses.
    """

    def __init__(self, vocab_size, embedding_dim, rnn_units):
        """Build the layers.

        Args:
            vocab_size: number of embedding rows and Dense output units.
            embedding_dim: size of each embedding vector.
            rnn_units: number of GRU units.
        """
        # No positional arguments: passing `self` here would hand the instance
        # to Model.__init__ as a stray extra argument.
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            rnn_units,
            return_sequences=True,
            return_state=True,
            reset_after=True,
        )
        # No activation is given, so outputs are unnormalized.
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        """Forward pass.

        Args:
            inputs: indices fed to the Embedding layer.
            states: initial GRU state; when None, ``gru.get_initial_state``
                is used.
            return_state: when True, also return the final GRU state.
            training: forwarded to every sub-layer.

        Returns:
            The Dense outputs, or ``(outputs, states)`` if ``return_state``.
        """
        x = self.embedding(inputs, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        return x

    # Not decorated with @tf.function: Keras wraps train_step itself when
    # fit() builds its train function, and skips that when run_eagerly=True.
    # A manual decorator would force graph mode regardless.
    def train_step(self, inputs):
        """Run one optimization step on a batch; called by ``fit()``.

        Args:
            inputs: an ``(inputs, labels)`` pair as yielded by the dataset.

        Returns:
            Dict mapping each metric name to its current value.
        """
        features, labels = inputs
        with tf.GradientTape() as tape:
            # self(...) dispatches to call() -- this is where call() is used.
            predictions = self(features, training=True)
            # Loss configured in compile(), plus any layer regularization losses.
            loss = self.compiled_loss(
                labels, predictions, regularization_losses=self.losses
            )
        # Differentiate w.r.t. this model's own variables, not a global object.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss).
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}