TensorFlow 2.0:每个函数都需要在顶部加上@tf.function装饰器吗?

24
在TensorFlow 2.0中(目前仍为alpha版本),我知道你可以使用修饰符@tf.function将普通的Python代码转换为图形。 每次需要这样做时,我是否必须在每个函数的顶部放置@tf.function?而@tf.function只考虑以下函数块吗?

4
请看这里:https://github.com/tensorflow/addons/issues/13 - Sharky
2个回答

19

@tf.function将Python函数转换为其图形表示。

要遵循的模式是定义训练步骤函数,这是最具计算密集型的函数,并使用@tf.function装饰它。

通常,代码看起来像:

#model,loss, and optimizer defined previously

@tf.function
def train_step(features, labels):
   with tf.GradientTape() as tape:
        predictions = model(features)
        loss_value = loss(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

for features, labels in dataset:
    lv = train_step(features, label)
    print("loss: ", lv)

17

@tf.function装饰器应用于其后紧接着的函数块,而由其调用的任何函数也将在图模式下执行。请参见Effective TF2 guide,其中写道:

在TensorFlow 2.0中,用户应将其代码重构为更小的函数,并根据需要调用它们。通常情况下,不必使用tf.function装饰这些较小的函数;只需使用tf.function来装饰高级计算 - 例如训练的一步,或您的模型的前向传递。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接