TensorFlow Keras“准确性”指标的实现原理分析

3
使用TensorFlow Keras构建分类器时,通常在编译阶段指定metrics=['accuracy']以监控模型准确率。
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])

无论模型输出逻辑或类别概率,以及是否期望使用独热编码向量或整数索引作为真实标签,本方法都能正确处理。但如果要使用交叉熵损失,则需要针对上述四种情况中的每一种进行不同的“loss”值传递:
1. 如果模型输出概率且真实标签是独热编码,则可以使用“loss='categorical_crossentropy'”。 2. 如果模型输出概率且真实标签是整数索引,则可以使用“loss='sparse_categorical_crossentropy'”。 3. 如果模型输出逻辑且真实标签是独热编码,则可以使用“loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)”。 4. 如果模型输出逻辑且真实标签是整数索引,则可以使用“loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)”。
似乎仅仅指定loss='categorical_crossentropy'并不足以处理这四种情况,而指定metrics=['accuracy']则足够强健。

问题 当用户在模型编译步骤中指定 metrics=['accuracy'] 时,背后发生了什么,使得准确性计算可以正确执行,无论模型输出是对数还是概率,以及基础事实标签是独热编码向量还是整数索引?


我猜测logits和概率的情况很简单,因为预测的类别可以通过argmax以任一方式获得,但是我希望能够指出在TensorFlow 2源代码中实际执行计算的位置。
请注意,我目前正在使用 TensorFlow 2.0.0-rc1

编辑 在纯Keras中,metrics=['accuracy']Model.compile方法中被明确处理。

1个回答

4

找到了:这个问题在tensorflow.python.keras.engine.training_utils.get_metric_function中处理。特别地,会检查输出形状以确定使用哪个准确度函数。

为了详细说明,当前实现中Model.compile将度量处理委托给Model._compile_eagerly(如果正在执行即时模式)或直接进行处理。在任一情况下,都会调用Model._cache_output_metric_attributes,该函数调用未加权和加权度量的collect_per_output_metric_info。此函数循环遍历提供的指标,并对每个指标调用get_metric_function

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接