在TensorFlow中,何时使用model.predict(x)和model(x)?

5

我有一个使用tf.keras.models.load_model加载的keras.models.Model模型。

现在有两种使用这个模型的方法。一种是调用model.predict(x),另一种是调用model(x).numpy()。两种方法都会给出相同的结果,但是model.predict(x)需要运行时间长达10倍。

源代码中的注释如下:

计算是分批进行的。此方法专为大规模输入设计,对于适合于一个批次的少量输入,建议直接使用__call__以获得更快的执行速度,例如:model(x)model(x, training=False)

我已经测试了包含1、1,000,000和10,000,000 行数据的x,发现 model(x) 性能更好。

那么,输入需要多大才能被归类为大规模输入?model.predict(x)表现得更好呢?


2
这可能会有所帮助。 - OverLordGoldDragon
1个回答

3
已有一篇现成的stackoverflow回答,你可能会发现有用:https://dev59.com/AVMH5IYBdhLWcg3wxic9#58385156。我是在tensorflow/tensorflow#33340上找到的。该回答建议在model.compile调用中传递experimental_run_tf_function=False,以恢复到TF 1.x版本的模型执行方式。您也可以完全省略model.compile调用(对于预测不必要)。

输入需要多大才能被归类为大规模输入,并使 model.predict(x)执行更好?

这是您可以测试的事情。正如文档所述,如果您的数据适合一个批次,则model(x)可能比model.predict(x)更快。model.predict(x)相比model(x)提供的一个功能是能够预测多个批次。如果您想使用model(x)预测多个批次,您必须自己编写循环。model.predict还提供其他功能,例如回调。
顺便说一下,源代码中的文档是在提交42f469be0f3e8c36624f0b01c571e7ed15f75faf后添加的,这是由于tensorflow/tensorflow#33340而引起的。 model.predict(x)的主要行为在此处实现。它包含的不仅仅是模型的前向传递。这可能解释了一些速度差异。

我已经测试过包含1、1,000,000和10,000,000行的x,但 model(x) 仍然表现得更好。

这个10,000,000行数据是否适合一个批次……?

谢谢您。我想我对所有这些都不够了解,无法真正理解正在发生的事情。批处理是否是我可以适合(GPU)内存的任意数据量? - Gunnarsi
批量大小是模型一次迭代中使用的示例数量。您可以将批量大小设置为适合您GPU的任何大小,但是过大的批量大小可能会导致过度拟合([Keskar et al.,2016](https://arxiv.org/abs/1609.04836))。 - jkr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接