Tensorflow 2.0模型使用tf.function非常缓慢,并且每次训练计数更改时都会重新编译。Eager运行速度快约4倍。

17

我有使用未编译的Keras代码构建的模型,并尝试通过自定义训练循环运行它们。

默认情况下,TF 2.0 eager代码在CPU(笔记本电脑)上运行约30秒。当我使用包装tf.function调用方法创建keras模型时,它的运行速度要慢得多,似乎需要很长时间才能启动,特别是“第一次”。

例如,在tf.function代码中,10个样本的初始训练需要40秒,后续训练只需2秒。

在20个样本上,初始训练需要50秒,后续训练需要4秒。

第一次1个样本的训练需要2秒,后续训练需要200毫秒。

因此,每次train的调用似乎都会创建一个新图形,其中复杂性随着train计数而增加!?

我只是做了这样的事:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

模型是keras.model.Model,其具有像示例中那样使用@tf.function修饰的call方法。

1个回答

35

我在这里分析了@tf.function的行为,具体请参见使用Python本地类型

简而言之:设计tf.function不会自动将Python本地类型转换为具有明确定义的dtypetf.Tensor对象。

如果您的函数接受tf.Tensor对象,则在第一次调用时分析该函数,构建图并与该函数相关联。在每个非第一次调用中,如果tf.Tensor对象的dtype匹配,则重用图形。

但是,在使用Python本地类型的情况下,每次使用不同值调用函数时都会构建图形。

简而言之:如果您计划使用@tf.function,请将代码设计为在任何地方都使用tf.Tensor而不是Python变量。

tf.function不是一个神奇的包装器,可以神奇地加速在急切模式下运行良好的函数;它是一个包装器,需要设计急切函数(主体、输入参数、dytpes),以便了解创建图形后会发生什么,以获得真正的加速。


5
这很好...我猜文档中应该有一个很大的警告。如果有的话,我肯定错过了它。 - mathtick
我不理解的一件事是,为什么在这里https://www.tensorflow.org/alpha/guide/effective_tf2中有tf.function的例子,如果这是autograph已知的问题,那么为什么还要在args中包含模型等内容。 - mathtick
4
将一个Keras对象、tf.data.dataset或任何tf.*对象传递并不成问题。性能只有在传递Python本地类型时才会降低。 - nessuno
我很乐意帮忙! - nessuno
2
谢谢,@nessuno。我知道你是在指 numeric 类型,但是我想补充说,即使列表在 Python 中也被认为是本地类型,张量的列表仍然可以正常工作。 - Vivek Subramanian

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接