如何使用tf.estimator(无论是预测还是评估方法)获取返回的预测和标签?

8

我正在使用Tensorflow 1.4。

我创建了一个自定义的tf.estimator来进行分类,就像这样:

def model_fn():
    # Some operations here
    [...]

    return tf.estimator.EstimatorSpec(mode=mode,
                           predictions={"Preds": predictions},
                           loss=cost,
                           train_op=loss,
                           eval_metric_ops=eval_metric_ops,
                           training_hooks=[summary_hook])

my_estimator = tf.estimator.Estimator(model_fn=model_fn, 
                       params=model_params,
                       model_dir='/my/directory')

我可以轻松地训练它:
input_fn = create_train_input_fn(path=train_files)
my_estimator.train(input_fn=input_fn)

其中,input_fn 是一个从tfrecords文件中读取数据的函数,使用 tf.data.Dataset API

由于我是从 tfrecords 文件中读取数据,因此在进行预测时,我没有内存中的标签

我的问题是,如何通过predict()方法或evaluate()方法返回预测结果和标签?

似乎没有办法两者兼得。 predict()没有访问标签的权限,而使用evaluate()方法无法访问predictions字典。


正如你所说,predict中没有标签(因为那是用于推理的,也就是用于分类新数据)。问题在于evaluate调用不会返回标签,因为它运行循环遍历整个数据集并计算汇总指标,然后返回这些指标。 如果您想要每个batch的预测和标签,您需要从checkpoint中加载模型,并使用tf.Session()循环执行sess.run([predictions, labels])直到数据用尽。 - GPhilo
看起来很傻,但我该如何以这种方式检索标签?我可以在检查点文件中添加logits(例如使用tf.add_to_collection),但无法使用标签。 - Benjamin Larrousse
@GPhilo,你有什么想法吗?我错过了什么吗? - Benjamin Larrousse
我不确定我理解你想要做什么。 - GPhilo
在使用tf.Estimator训练模型之后,我想导出两个列表,一个包含标签,一个包含预测值,以便进行分析(例如校准曲线)。但是正如我们所说的那样,我必须创建一个tf.Session(),但我无法执行sess.run([predictions, labels]),因为标签是使用tf.data.Dataset API从tfrecords动态读取的,并且似乎我不能保存一个Tensor,它保存这些标签值并通过我的检查点检索。 - Benjamin Larrousse
显示剩余2条评论
1个回答

11

完成训练后,在'/my/directory'中您会有一堆检查点文件。

您需要重新设置输入流水线,手动加载其中一个文件,然后开始循环遍历批次并存储预测和标签:

# Rebuild the input pipeline
input_fn = create_eval_input_fn(path=eval_files)
features, labels = input_fn()

# Rebuild the model
predictions = model_fn(features, labels, tf.estimator.ModeKeys.EVAL).predictions

# Manually load the latest checkpoint
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('/my/directory')
    saver.restore(sess, ckpt.model_checkpoint_path)

    # Loop through the batches and store predictions and labels
    prediction_values = []
    label_values = []
    while True:
        try:
            preds, lbls = sess.run([predictions, labels])
            prediction_values += preds
            label_values += lbls
        except tf.errors.OutOfRangeError:
            break
    # store prediction_values and label_values somewhere

更新:更改为直接使用您已经拥有的model_fn函数。


1
看看我的更新代码,你可以通过EstimatorSpecpredictions属性简单地获取预测结果。 - GPhilo
关于为什么分类错误,这是因为您正在初始化所有变量。这意味着清除已训练的权重并将它们重新初始化为随机值。当然,这会导致预测出现错误。 - GPhilo
@GPhilo,你能否看一下这个问题:https://dev59.com/Yajka4cB1Zd3GeqPFeLH。这是我用来构建模型的代码 - Effective_cellist
打开一个新问题,将所有必要的信息和您的代码放在其中,然后在此处链接它,我会查看。 - GPhilo
@GPhilo 我已经在这里开了一个问题(https://dev59.com/Yajka4cB1Zd3GeqPFeLH)。请看一下。感谢你的时间。 - Effective_cellist
显示剩余5条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接