使用哪种损失函数和指标进行高负样本比例的多标签分类?

26

我正在训练一个多标签分类模型来检测服装的属性。我在Keras中使用迁移学习,重新训练vgg-19模型的最后几层。

总属性数为1000,其中大约99%是0。像准确率、精度、召回率等指标都无法正确评估,因为模型可以预测全部为零的情况下仍然获得非常高的分数。在损失函数方面,二元交叉熵、汉明损失等都没有效果。

我正在使用深度时尚数据集。

那么,我应该使用哪些指标和损失函数来正确评估我的模型呢?

5个回答

40

哈桑提出的建议是不正确的 - 分类交叉熵损失或Softmax损失是一个Softmax激活和一个交叉熵损失。如果我们使用这个损失函数,我们将训练卷积神经网络为每张图像输出C类别的概率,它被用于多类别分类

你所需的是多标签分类,因此你需要使用二元交叉熵损失或Sigmoid交叉熵损失。它是一个Sigmoid激活和一个交叉熵损失。与Softmax损失不同,它对每个向量组件(类)都是独立的,意味着为每个CNN输出向量组件计算的损失值不受其他组件值的影响。这就是为什么它被用于多标签分类的原因,其中一个元素属于某个类别的认识不应影响到决定另一个类别的判断。

现在对于处理类别不平衡问题,可以使用加权Sigmoid交叉熵损失。因此,你将基于正例的数量/比率来对错误预测进行惩罚。


真的。softmax 用于多类分类(这里是图10 - https://developers.google.com/machine-learning/guides/text-classification/step-4 -)... 感谢有关不平衡数据中多标签分类的建议。 - JeeyCi
只是一个问题:在使用softmax进行多类分类时,在反向传播和优化阶段,我们尝试更新权重以最大化估计的类概率并最小化错误估计概率。因此,如果我们对于任何样本有几个真实的y(例如[1,0,0,0,1,1]),在反向传播和优化期间,我们会操作真实类别的权重以最小化概率。这是正确的吗? - Mahdi Amrollahi
@MahdiAmrollahi 是的 - Ritwik

9
实际上,你应该使用 tf.nn.weighted_cross_entropy_with_logits。 它不仅适用于多标签分类,还有一个pos_weight可以更加重视你期望的正类。

1
Multi-class和binary-class分类决定了输出单元的数量,即最后一层神经元的数量。 Multi-label和single-Label决定了最后一层激活函数和损失函数的选择。 对于single-label,标准选择是使用Softmax和分类交叉熵;对于multi-label,则使用Sigmoid激活和二元交叉熵。
分类交叉熵:

enter image description here

二元交叉熵:

enter image description here

C 是类别的数量,m 是当前小批量中的示例数。 L 是损失函数,J 是代价函数。 您也可以在 这里 看到。 在损失函数中,您正在迭代不同的类别。 在代价函数中,您正在迭代当前小批量中的示例。


1

1
对于那些感到困惑的人,焦点损失是一种自定义损失函数,使得“良好”的预测对总体损失产生较小影响,而“错误”的预测与常规损失函数产生大致相同的影响。对于稀疏输出,这意味着你需要迫使网络直面所犯的错误,同时大部分忽略它所正确的(随机猜测多数输出的情况)。 - Austin

-1

我曾经处于和你类似的情况。

你可以在输出层使用softmax激活函数,并结合分类交叉熵检查其他指标,例如精度、召回率和F1分数,你可以按以下方式使用sklearn库:

from sklearn.metrics import classification_report

y_pred = model.predict(x_test, batch_size=64, verbose=1)
y_pred_bool = np.argmax(y_pred, axis=1)

print(classification_report(y_test, y_pred_bool))

关于培训阶段,据我所知有以下准确度指标

model.compile(loss='categorical_crossentropy'
              , metrics=['acc'], optimizer='adam')

如果有帮助的话,您可以使用matplotlib绘制训练阶段的损失和准确率的训练历史,方法如下:
hist = model.fit(x_train, y_train, batch_size=24, epochs=1000, verbose=2,
                 callbacks=[checkpoint],
                 validation_data=(x_valid, y_valid)

                 )
# Plot training & validation accuracy values
plt.plot(hist.history['acc'])
plt.plot(hist.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接