使用TensorFlow Keras构建分类器时,通常在编译阶段指定
无论模型输出逻辑或类别概率,以及是否期望使用独热编码向量或整数索引作为真实标签,本方法都能正确处理。但如果要使用交叉熵损失,则需要针对上述四种情况中的每一种进行不同的“loss”值传递:
1. 如果模型输出概率且真实标签是独热编码,则可以使用“loss='categorical_crossentropy'”。 2. 如果模型输出概率且真实标签是整数索引,则可以使用“loss='sparse_categorical_crossentropy'”。 3. 如果模型输出逻辑且真实标签是独热编码,则可以使用“loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)”。 4. 如果模型输出逻辑且真实标签是整数索引,则可以使用“loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)”。
似乎仅仅指定
我猜测logits和概率的情况很简单,因为预测的类别可以通过argmax以任一方式获得,但是我希望能够指出在TensorFlow 2源代码中实际执行计算的位置。
请注意,我目前正在使用 TensorFlow 2.0.0-rc1。
metrics=['accuracy']
以监控模型准确率。model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
无论模型输出逻辑或类别概率,以及是否期望使用独热编码向量或整数索引作为真实标签,本方法都能正确处理。但如果要使用交叉熵损失,则需要针对上述四种情况中的每一种进行不同的“loss”值传递:
1. 如果模型输出概率且真实标签是独热编码,则可以使用“loss='categorical_crossentropy'”。 2. 如果模型输出概率且真实标签是整数索引,则可以使用“loss='sparse_categorical_crossentropy'”。 3. 如果模型输出逻辑且真实标签是独热编码,则可以使用“loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)”。 4. 如果模型输出逻辑且真实标签是整数索引,则可以使用“loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)”。
似乎仅仅指定
loss='categorical_crossentropy'
并不足以处理这四种情况,而指定metrics=['accuracy']
则足够强健。
问题
当用户在模型编译步骤中指定 metrics=['accuracy']
时,背后发生了什么,使得准确性计算可以正确执行,无论模型输出是对数还是概率,以及基础事实标签是独热编码向量还是整数索引?
我猜测logits和概率的情况很简单,因为预测的类别可以通过argmax以任一方式获得,但是我希望能够指出在TensorFlow 2源代码中实际执行计算的位置。
请注意,我目前正在使用 TensorFlow 2.0.0-rc1。
编辑
在纯Keras中,metrics=['accuracy']
在Model.compile
方法中被明确处理。