多类多目标分类问题的最佳损失函数

3
我有一个分类问题,但不知道如何对其进行分类。据我所知,
多类别分类问题是指您有多个互斥的类别,并且数据集中的每个数据点只能被一个类别标记。例如,在水果图像分类任务中,标记为苹果的水果数据点不能是橙子,橙子不能是香蕉等等。在这种情况下,每个数据点只能是水果类别中的任意一种水果,并相应地标记。
而多标签分类是指您有多组互斥的类别,其中数据点可以同时被标记。例如,在汽车图像分类任务中,标记为轿车的汽车数据点不能是掀背车,掀背车不能是SUV等等。同时,同一辆车的数据点可以从VW、Ford、Mercedes等品牌中选择一个进行标记。因此,在这种情况下,汽车数据点从两组互斥的类别中标记。
如果我理解有误,请纠正我。现在,我的分类问题涉及多个类别,假设为A、B、C、D和E。每个数据点可以具有来自左侧显示的集合中一个或多个类别。
|-------------|----------|              |-------------|-----------------|
|      X      |     y    |              |      X      |    One-Hot-Y    |
|-------------|----------|              |-------------|-----------------|
|     DP1     |   A, B   |              |     DP1     | [1, 1, 0, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP2     |   C      |              |     DP2     | [0, 0, 1, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP3     |   B, E   |              |     DP3     | [0, 1, 0, 0, 1] |
|-------------|----------|              |-------------|-----------------|
|     DP4     |   A, C   |              |     DP4     | [1, 0, 1, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP5     |   D      |              |     DP5     | [0, 0, 0, 1, 0] |
|-------------|----------|              |-------------|-----------------|

我已经按照要求进行了翻译,以下是您需要的结果:

如上所示,我对训练标签进行了One-Hot编码。我的问题是:

  1. 为了优化One-Hot编码输出,我可以使用哪种损失函数(最好是PyTorch)来训练模型?
  2. 我们如何称呼这种分类问题?多标签或多类别?

感谢您的答案!

1个回答

4

我应该使用什么损失函数(最好使用PyTorch)来训练模型以优化One-Hot编码的输出?

您可以使用torch.nn.BCEWithLogitsLoss(或MultiLabelSoftMarginLoss,因为它们是等效的),然后查看这个方法的效果。这是标准方法,其他可能性是MultilabelMarginLoss

我们把这样的分类问题称为什么?多标签或多类?

这是多标签问题(因为多个标签可以同时存在)。在One-Hot编码中:

[1, 1, 0, 0, 0], [0, 1, 0, 0, 1] - multilabel
[0, 0, 1, 0, 0] - multiclass
[1], [0] - binary (special case of multiclass)

多分类问题不能有超过一个1,因为所有其他标签是相互排斥的。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接