将训练好的Tensorflow模型加载到评估器中

7

假设我已经训练好了一个Tensorflow Estimator:

estimator = tf.contrib.learn.Estimator(
  model_fn=model_fn,
  model_dir=MODEL_DIR,
  config=some_config)

然后我将其适配到一些训练数据中:

estimator.fit(input_fn=input_fn_train, steps=None)

想法是将模型拟合到我的MODEL_DIR。这个文件夹包含一个检查点和几个.meta.index文件。

这非常完美。我想使用我的函数进行一些预测:

estimator = tf.contrib.Estimator(
  model_fn=model_fn,
  model_dir=MODEL_DIR,
  config=some_config)

predictions = estimator.predict(input_fn=input_fn_test)

我的解决方案完美运作,但有一个很大的缺点:你需要知道model_fn,这是我在Python中定义的模型。但如果我在Python代码中添加了一个密集层来改变模型,那么这个模型对于在MODEL_DIR中保存的数据来说是不正确的,导致结果不正确:

NotFoundError (see above for traceback): Key xxxx/dense/kernel not found in checkpoint

我该怎么应对这个问题?我该如何加载我的模型/估算器,以便在一些新数据上进行预测?怎样才能从MODEL_DIR中加载model_fn或estimator?

1个回答

1

避免糟糕的恢复

仅当模型和检查点兼容时,从检查点恢复模型状态才有效。例如,假设您训练了一个包含两个隐藏层的DNNClassifier估计器,每个隐藏层都有10个节点:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

在训练完成后(因此,在models/iris中创建检查点之后),想象一下您将每个隐藏层中的神经元数量从10个更改为20个,然后尝试重新训练模型:
classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由于检查点中的状态与classifier2中描述的模型不兼容,因此重新训练会失败并出现以下错误:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

为了运行实验,您需要训练和比较模型的不同版本。请保存创建每个model_dir的代码副本,可以通过为每个版本创建单独的git分支来实现。这种分离将保持您的检查点可恢复性。
注:此内容摘自TensorFlow检查点文档。

https://www.tensorflow.org/get_started/checkpoints

希望能够帮到您。

原始模型为什么不以更通用的格式存储?我理解你的解决方案需要为每个模型存储Python代码,但这感觉有点不正规。 - Guido
你的代码构建了模型结构(即每层的层数和神经元数量),如果你的模型结构发生变化,你的备份模型将无法匹配该模型,从而引发错误。因此,每个结构(Python 代码)都将链接一个备份模型。 - Colin Wang

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接