使用tf.function时如何计算梯度

7
我对以下示例中的行为感到困惑:
import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c


a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
    
print('tf.function gradient: ', tape.gradient([b], [c]))

# outputs: tf.function gradient:  [None]

with tf.GradientTape() as tape:
    b, c = fplain(a)
    
print('plain gradient: ', tape.gradient([b], [c]))

# outputs: plain gradient:  [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]

下面的行为是我所期望的。 如何理解@tf.function 的情况?

非常感谢您提前的帮助!

(请注意,此问题与使用 tf.function 时缺少梯度 不同,因为这里所有的计算都在函数内部。)

1个回答

9
梯度记录带不记录在由@tf.function生成的tf.Graph内部的操作,将该函数视为一个整体。大致来说,f应用于a,梯度带记录了对于输入af输出的梯度(只有watched variable,即tape.watched_variables())。
第二种情况中没有生成图形,操作以Eager模式应用。所以一切都按预期进行。
一个好的做法是将计算最昂贵的函数(通常是训练循环)包装在@tf.function中。在您的情况下,应该像这样:
@tf.function
def f(a):
    with tf.GradientTape() as tape:
        c = a * 2
        b = tf.reduce_sum(c ** 2 + 2 * c)
    grads = tape.gradient([b], [c])
    print('tf.function gradient: ', grads)
    return grads

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接