在tf.Estimator设置中使用tf.metrics.precision/recall计算F1分数

3
我正在尝试在tf.Estimator设置中计算F1得分。
我看到了这个SO问题,但无法从中提炼出可行的解决方案。 tf.Estimator的问题在于它希望我提供一个值和一个更新操作,所以现在,我在模型结尾处有以下代码:
if mode == tf.estimator.ModeKeys.EVAL:
    with tf.variable_scope('eval'):
        precision, precision_update_op = tf.metrics.precision(labels=labels,
                                            predictions=predictions['class'],
                                            name='precision')

        recall, recall_update_op = tf.metrics.recall(labels=labels,
                                      predictions=predictions['class'],
                                      name='recall')

        f1_score, f1_update_op = tf.metrics.mean((2 * precision * recall) / (precision + recall), name='f1_score')

        eval_metric_ops = {
            "precision": (precision, precision_update_op),
            "recall": (recall, recall_update_op),
            "f1_score": (f1_score, f1_update_op)}

现在准确率和召回率似乎都运行正常,但是对于 F1 分数,我一直得到 nan

我该如何使其正常运行?

编辑:

可以使用 tf.contrib.metrics.f1_score 来实现一个可行的解决方案,但由于 contrib 在 TF 2.0 中将被弃用,因此我需要一个不依赖 contrib 的解决方案。

4个回答

1
我是这样做的:

def f1_score_class0(labels, predictions):
    """
    To calculate f1-score for the 1st class.
    """
    prec, update_op1 = tf.compat.v1.metrics.precision_at_k(labels, predictions, 1, class_id=0)
    rec,  update_op2 = tf.compat.v1.metrics.recall_at_k(labels, predictions, 1, class_id=0)

    return {
            "f1_Score_for_class0":
                ( 2*(prec * rec) / (prec + rec) , tf.group(update_op1, update_op2) )
    }

0

1) 你为什么要使用tf.metrics.mean?召回率和精确度是标量值。

2) 你尝试过打印f1_scoref1_update_op吗?

3) 从召回率文档中可以看到:

为了在数据流中估计指标,该函数创建一个update_op来更新这些变量并返回召回率。update_op通过权重weights对每个预测进行加权。

由于你直接从处理更新的两个操作中获取F1分数,请尝试使用tf.identity(它实际上不会引起任何更改)。


你不能这样做,因为eval_metric_ops期望得到一个(value, update_op)元组。 - bluesummers
是的,第一个值是您计算的f1_score,第二个操作只是身份操作吗? - IanQ


0

f1值张量可以从精确度和召回率值计算得出。指标必须是(值,update_op)元组。我们可以为f1传递tf.identity。这对我有用:

import tensorflow as tf

def metric_fn(labels, logits):
    predictions = tf.argmax(logits, axis=-1)
    pr, pr_op = tf.metrics.precision(labels, predictions)
    re, re_op = tf.metrics.recall(labels, predictions)
    f1 = (2 * pr * re) / (pr + re)
    return {
        'precision': (pr, pr_op),
        'recall': (re, re_op),
        'f1': (f1, tf.identity(f1))
    }

tf.identity(f1) 的意思是什么,期望的操作是什么? - bluesummers
它将张量转换为虚拟操作,因为tf.metrics必须返回元组(值张量,更新操作)。 - Bohumir Zamecnik
你确定这行代码可行吗?根据我的理解,op(操作符)用于能够批量计算度量指标,但在这里它没有任何意义,我想知道它是否会破坏分数。 - bluesummers

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接