使用XLA tf.function运行TensorFlow时出现错误

3
当我尝试编译这段代码时,会出现以下错误。
File "xla_test.py", line 25, in <module>
    @tf.function(jit_compile=True)
TypeError: function() got an unexpected keyword argument 'jit_compile'
2个回答

2
不需要切换到tf-nightly,只需使用以下代码:

@tf.function(experimental_compile=True)

来自tensorflow文档

experimental_compile 如果为True,则函数始终由XLA编译。在某些情况下(例如TPU、XLA_GPU、密集张量计算),XLA可能更有效率。

在我的情况下,MCMC采样没有使用该参数:约1分37秒,使用experimental_compile=True:约6秒。Tensorflow是从源代码构建的(r2.4分支)。


0
安装 tf-nightly 解决了这个问题。
pip install tf-nightly

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接