使用TensorFlow Estimator模型(2.0)进行保存、加载和预测

7

是否有关于在TF2中序列化和还原Estimator模型的指南?文档非常零散,并且大部分内容都未更新至TF2。我仍然没有看到任何清晰而完整的示例,说明如何将Estimator保存、加载到磁盘并用于对新输入进行预测。

说实话,我有点困惑,因为这似乎很复杂。 Estimators 被标榜为简单、相对高级的拟合标准模型的方式,但在生产环境中使用它们的过程似乎非常神秘。例如,当我通过tf.saved_model.load(export_path)从磁盘加载模型时,我得到一个 AutoTrackable 对象:

<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7fc42e779f60>

不清楚为什么我没有得到我的Estimator。看起来以前有一个听起来很有用的函数tf.contrib.predictor.from_saved_model,但由于contrib已经消失了,它似乎不再发挥作用(除了TFLite之外)。

任何指针都将非常有帮助。正如您所看到的,我有点失落。


你查看了2.0版本的tf.saved_model.load文档吗?它与1.x版本不同。"从TensorFlow 1.x导入SavedModels"部分适用于估算器,并解释了如何使用“prune”获取可调用来评估模型。 - jdehesa
好的,这可能是混淆的一部分,因为我没有使用TensorFlow 1.x。在那些文档中,没有清晰地阐述如何获取我的Estimator(在我的情况下是DNNRegression)。 - Chris Fonnesbeck
也许我在使用“Estimator”时犯了错误,不适合我所需的功能。也许我应该完全使用Keras接口。我只想生成一些模型,将它们序列化,然后根据需要加载它们以进行预测。 - Chris Fonnesbeck
啊,所以你想要检索实际的评估器对象,而不仅仅是能够加载模型并运行它,对吗?是的,似乎没有专门为此提供API,只有一般用于运行模型或者Keras模型的API... 但总体来说,使用Keras更容易,如果需要评估器接口,则可以从Keras模型创建一个评估器 - jdehesa
“我有点困惑,这似乎很复杂” <-- 实在是太对了。Keras 让所有事情都变得更加困难,而不是更容易。 :) - o-90
1
如果您只是重新创建相同的Estimator(具有相同的model_fn或相同的参数的预制Estimator),并重复使用相同的输出目录,则它将重新加载并准备好使用.predict。如果您不想或没有源代码来重新创建相同的Estimator,则saved_model.load将为您提供一个可以查询的对象(用于预测),但该对象将不具有Estimator API。 - wicke
1个回答

6
也许作者不再需要答案了,但我能够使用 TensorFlow 2.1保存和加载 DNNClassifier。
# training.py
from pathlib import Path
import tensorflow as tf

....
# Creating the estimator
estimator = tf.estimator.DNNClassifier(
    model_dir= < model_dir >,
    hidden_units = [1000, 500],
    feature_columns = feature_columns,  # this is a list defined earlier
    n_classes = 2,
    optimizer = 'adam'
)

feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
servable_model_path = Path(estimator.export_saved_model( < model_dir >, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')

对于加载,你找到了正确的方法,你只需要获取predict_fn

# testing.py
import tensorflow as tf
import pandas as pd

def predict_input_fn(test_df):
    '''Convert your dataframe using tf.train.Example() and tf.train.Features()'''
    examples = []
    ....
    return tf.constant(examples)

test_df = pd.read_csv('test.csv', ...)

# Loading the estimator
predict_fn = tf.saved_model.load(<model_dir>).signatures['predict']
# Predict
predictions = predict_fn(examples=predict_input_fn(test_df))

希望这也能帮到其他人(:


1
我已经完全放弃使用Estimators,但这将是其他人有用的参考。我会接受这个答案。 - Chris Fonnesbeck
2
我简直无法相信使用当前的Estimator API保存和加载模型等简单操作是多么麻烦。与之相比,Pytorch如此简单易懂。 - Daniyal Shahrokhian
嗨,@Omar Contugno,你能否详细说明一下如何使用tf.train.Example()tf.train.Features()转换数据框?我在你的回答中遇到了困难。 - Avi Vajpeyi
Scala/Java怎么样? - uuball

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接