在Tensorflow Keras中创建自定义度量类

3
我希望构建一个指标,用于计算组级别的精确度。
例如,假设LSTM的输出形状为(batch, 10, 1),我想沿着时间维度(按10个时间戳分组),并计算精确度。
我创建了这个指标,它继承了精确度:
class PrecisionGrouped(tf.keras.metrics.Precision):
  def __init__(self,
               thresholds=None,
               top_k=None,
               class_id=None,
               name=None,
               dtype=None):
    super(PrecisionGrouped, self).__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.math.reduce_max(y_true, axis=1)
    y_pred = tf.math.reduce_max(y_pred, axis=1)
    return super().update_state(y_true, y_pred, sample_weight)

然而,当我运行代码时,它抱怨update_state方法应该返回张量。 但是我只是调用了父类的方法,它也返回了update_op。
TypeError: To be compatible with tf.contrib.eager.defun, Python functions must return zero or more Tensors; in compilation of <function PrecisionGrouped.update_state at 0x1a384b1400>, found return value of type <class 'tensorflow.python.framework.ops.Operation'>, which is not a Tensor.

我所做的就是给输入添加一个简单的预处理步骤。tf.keras.metrics.Precision()按预期工作,但PrecisionGrouped()不行。
1个回答

4

您可以从update_state()方法中删除换行符。

def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.math.reduce_max(y_true, axis=1)
    y_pred = tf.math.reduce_max(y_pred, axis=1)
    super().update_state(y_true, y_pred, sample_weight)

你可以从自定义指标中删除返回语句和组操作,这不是必需的。由于TPU存在问题,内置指标具有不同的要求。一旦问题得到解决,我们也会从内置指标的update_state中删除返回语句。更多详细信息,请参阅GitHub issue

谢谢,这很有帮助。我添加的唯一一件事是 sample_weight=1,因为它必须与 y_true 广播。 - GRS

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接