在TensorFlow 2.0中(目前仍为alpha版本),我知道你可以使用修饰符
@tf.function
将普通的Python代码转换为图形。
每次需要这样做时,我是否必须在每个函数的顶部放置@tf.function
?而@tf.function
只考虑以下函数块吗?@tf.function
将普通的Python代码转换为图形。
每次需要这样做时,我是否必须在每个函数的顶部放置@tf.function
?而@tf.function
只考虑以下函数块吗?@tf.function
将Python函数转换为其图形表示。
要遵循的模式是定义训练步骤函数,这是最具计算密集型的函数,并使用@tf.function
装饰它。
通常,代码看起来像:
#model,loss, and optimizer defined previously
@tf.function
def train_step(features, labels):
with tf.GradientTape() as tape:
predictions = model(features)
loss_value = loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss_value
for features, labels in dataset:
lv = train_step(features, label)
print("loss: ", lv)
@tf.function装饰器应用于其后紧接着的函数块,而由其调用的任何函数也将在图模式下执行。请参见Effective TF2 guide,其中写道:
在TensorFlow 2.0中,用户应将其代码重构为更小的函数,并根据需要调用它们。通常情况下,不必使用tf.function装饰这些较小的函数;只需使用tf.function来装饰高级计算 - 例如训练的一步,或您的模型的前向传递。