使用混淆矩阵理解多标签分类器

9
我有一个包含12个类别的多标签分类问题。我使用Tensorflow的slim来训练模型,使用在ImageNet上预训练的模型。以下是每个类别在训练和验证中的出现百分比。
            Training     Validation
  class0      44.4          25
  class1      55.6          50
  class2      50            25
  class3      55.6          50
  class4      44.4          50
  class5      50            75
  class6      50            75
  class7      55.6          50
  class8      88.9          50
  class9     88.9           50
  class10     50            25
  class11     72.2          25

问题在于模型没有收敛,在验证集上的 ROC 曲线下面积(Az)很差,类似于:
               Az 
  class0      0.99
  class1      0.44
  class2      0.96  
  class3      0.9
  class4      0.99
  class5      0.01
  class6      0.52
  class7      0.65
  class8      0.97
  class9     0.82
  class10     0.09
  class11     0.5
  Average     0.65

我不知道为什么有些类表现良好,而其他类则不行。我决定深入了解神经网络学习的细节。我知道混淆矩阵仅适用于二元或多类分类。因此,为了能够绘制它,我必须将问题转换为一系列多类分类。尽管该模型使用sigmoid对每个类提供预测,但在下面的混淆矩阵中的每个单元格中,我显示了存在行中的类且列中不存在的图像的概率平均值(通过应用tensorflow预测的sigmoid函数获得)。这是在验证集图像上应用的。我认为这样可以更详细地了解模型正在学习什么。我只是为了展示目的而圈出了对角线元素。

enter image description here

我的解释是:
  1. 当类别0和4存在时,它们就会被检测到,并且在不存在时不会被检测到。这意味着这些类别被很好地检测到。
  2. 类别2、6和7总是被检测为不存在。这不是我要找的。
  3. 类别3、8和9总是被检测为存在。这不是我要找的。此规则也适用于类别11。
  4. 当类别5不存在时,则被检测为存在,当其存在时,则被检测为不存在。这是一种相反的检测结果。
  5. 类别3和10:我认为我们不能从这两个类别中提取太多信息。

我的问题是解释...我不确定问题出在哪里,也不确定数据集是否存在偏差导致出现这种结果。我还想知道是否有一些指标可以帮助解决多标签分类问题?您能否与我分享对这样的混淆矩阵的解释?以及接下来要看什么/在哪里寻找?对其他指标的建议也将非常有帮助。

谢谢。

编辑:

我将问题转化为多类分类,对于每一对类别(例如0,1),计算概率(类别0,类别1),表示为p(0,1): 我使用工具1的预测结果来处理存在工具0但不存在工具1的图像,并通过应用sigmoid函数将其转换为概率,然后展示这些概率的平均值。对于p(1,0),我做同样的处理,但现在是针对工具0,使用存在工具1但不存在工具0的图像。对于p(0,0),我使用所有存在工具0的图像。考虑上面图像中的p(0,4),N/A表示不存在工具0存在但不存在工具4存在的图像。
以下是2个子集的图像数量:
  1. 169320张训练图像
  2. 37440张验证图像
以下是在训练集上计算的混淆矩阵(与之前描述的验证集相同的方式计算),但这次颜色编码是用于计算每个概率的图像数量: enter image description here 编辑: 为了进行数据增强,我对每个输入图像进行随机平移、旋转和缩放。此外,以下是一些关于工具的信息:
class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.

编辑后: 以下是下面提出的代码在训练集上的输出:

Avg. num labels per image =  6.892700212615167
On average, images with label  0  also have  6.365296803652968  other labels.
On average, images with label  1  also have  6.601033718926901  other labels.
On average, images with label  2  also have  6.758548914659531  other labels.
On average, images with label  3  also have  6.131520940484937  other labels.
On average, images with label  4  also have  6.219187208527648  other labels.
On average, images with label  5  also have  6.536933407946279  other labels.
On average, images with label  6  also have  6.533908387864367  other labels.
On average, images with label  7  also have  6.485973817793214  other labels.
On average, images with label  8  also have  6.1241642788920725  other labels.
On average, images with label  9  also have  5.94092288040875  other labels.
On average, images with label  10  also have  6.983303518187239  other labels.
On average, images with label  11  also have  6.1974066621953945  other labels.

对于验证集:

Avg. num labels per image =  6.001282051282051
On average, images with label  0  also have  6.0  other labels.
On average, images with label  1  also have  3.987080103359173  other labels.
On average, images with label  2  also have  6.0  other labels.
On average, images with label  3  also have  5.507731958762887  other labels.
On average, images with label  4  also have  5.506459948320414  other labels.
On average, images with label  5  also have  5.00169779286927  other labels.
On average, images with label  6  also have  5.6729452054794525  other labels.
On average, images with label  7  also have  6.0  other labels.
On average, images with label  8  also have  6.0  other labels.
On average, images with label  9  also have  5.506459948320414  other labels.
On average, images with label  10  also have  3.0  other labels.
On average, images with label  11  also have  4.666095890410959  other labels.
评论: 我认为这不仅与分布之间的差异有关,因为如果模型能够很好地概括类别10(意味着对象在训练过程中被正确识别,就像类别0一样),那么验证集上的准确性就足够好了。我的意思是,问题在于训练集本身以及如何构建它,而不仅仅是两种分布之间的差异。可能是:类别存在的频率或对象强烈相似(例如类别10与类别9非常相似)或数据集内部存在偏差或薄对象(可能代表输入图像中1或2%的像素,如类别2)。我并不是说问题就是其中之一,但我只是想指出,我认为它不仅仅是两种分布之间的差异。

1
你能否详细解释一下矩阵中的值是如何计算的?N/A代表什么?是否存在除以0的情况?你的训练集和测试集有多大?你是否还有关于训练数据中哪些类别共同出现的信息(例如,如果绘制一个热力图,它是否与你的混淆矩阵相似)? - Dennis Soemers
@DennisSoemers,我编辑了我的问题以包含更多细节。 - Maystro
我对目标类别感到困惑。每个图像可以有多个目标类别吗?我认为这是一个“多标签分类问题”。你在神经网络中使用什么损失函数?在这里看一些不同的选项:https://en.wikipedia.org/wiki/Multi-label_classification#Statistics_and_evaluation_metrics - KPLauritzen
你的网络输出是什么?它不是已经为每个标签给出了一个“概率”(在[0,1]范围内的数字)吗?如果是这样,我不认为我理解为什么要应用额外的sigmoid来获取混淆矩阵中的数字。难道你不能直接取平均值吗? - Dennis Soemers
@KPLauritzen。是的,这是一个多标签分类问题。每个图像可以有零到n个类别。我使用sigmoid作为损失函数。 - Maystro
@DennisSoemers,不是概率,我得到的是logits(每个标签的实数),我必须应用sigmoid将其转换为概率。 - Maystro
1个回答

8

输出校准

首先需要认识到的一点是,神经网络的输出可能存在较差的校准性。我的意思是,它对不同实例的输出可能会导致良好的排名(带有标签 L 的图像往往具有比没有标签 L 的图像更高的该标签得分),但这些得分并不能始终可靠地解释为概率(它可能会给出非常高的得分,如0.9,给没有标签的实例,而只是给带有标签的实例更高的得分,如0.99)。我想这是否会发生取决于你选择的损失函数等因素。

欲了解更多信息,请参见: https://arxiv.org/abs/1706.04599


逐个处理所有类别

类别 0: AUC(曲线下面积)= 0.99。这是一个非常好的分数。混淆矩阵中的第0列看起来也很好,因此这里没有问题。

类别 1: AUC = 0.44。这相当糟糕,低于0.5,如果我没有弄错,这基本上意味着针对此标签预测与网络预测相反可能更好。

查看混淆矩阵中的第一列,它在各处得分几乎都相同。对我来说,这表明网络没有成功地学习到该类别的许多信息,而只是根据训练集中包含该标签的图像百分比(55.6%)“猜测”。由于此百分比在验证集中降至50%,因此这种策略确实意味着它会比随机稍微差一点。然而,第1行仍然是该列中所有行中最高的数字,因此它似乎至少学到了一点,但不多。

类别 2: AUC = 0.96。这非常好。

你对这个类别的解释是,它总是被预测为不存在,基于整个列的浅色阴影。但我认为这种解释是不正确的。请注意,它在对角线上有一个得分> 0,在该列的其他位置都为0。它在该行可能具有相对较低的得分,但很容易与该列中的其他行区分开来。您可能只需要将选择是否存在该标签的阈值设置得相对较低。我怀疑这是由上述校准问题引起的。

这就是为什么AUC实际上非常好的原因;可以选择一个阈值,使得大多数得分高于阈值的实例正确地具有标签,并且大多数低于阈值的实例正确地不具有标签。但是该阈值可能不是0.5,这是您假设良好校准时可能期望的阈值。绘制此特定标签的ROC曲线可能有助于确定阈值应该在哪里。
类别3:AUC = 0.9,相当不错。
您将其解释为始终被检测到存在,并且混淆矩阵确实在列中有很多高数字,但是AUC很好,并且对角线上的单元格确实具有足够高的值,以便可以轻松地与其他单元格区分开来。我怀疑这是类2的类似情况(只是翻转了,到处都是高预测,因此需要高阈值才能做出正确的决策)。
如果您想确定是否可以确保精选的阈值确实可以正确地将大多数“阳性”(具有类3的实例)从大多数“阴性”(没有类3的实例)中正确拆分,则需要根据标签3的预测分数对所有实例进行排序,然后浏览整个列表,在每对连续条目之间计算准确性,如果您决定在那里放置阈值,则可以获得验证集,并选择最佳阈值。
类别4:与类别0相同。
类别5:AUC = 0.01,显然很糟糕。也同意您对混淆矩阵的解释。很难确定为什么它在这里表现如此糟糕。也许这是一种难以识别的对象?可能还存在一些过度拟合(从您的第二个矩阵中的列来判断,在训练数据中有0个假阳性,尽管也有其他类别出现这种情况)。
从训练到验证数据,标签5图像的比例增加也可能没有帮助。这意味着在训练期间,网络在此标签上表现良好的重要性不如在验证期间。
类别6:AUC = 0.52,仅比随机略好。
根据第一个矩阵中的第6列来判断,这实际上可能是类2的类似情况。但是,如果我们考虑AUC,看起来它也无法很好地学习排名实例。与类5类似,只是没有那么糟糕。此外,再次,训练和验证分布非常不同。
类别7:AUC = 0.65,相当平均。显然不像类2那样好,但也不像您从矩阵中可能解释的那样糟糕。
类别8:AUC = 0.97,非常好,与类别3相似。

第9类:AUC = 0.82,不错但也不算很好。矩阵中该列有很多深色单元格,数字非常接近,所以我认为AUC值非常好。它几乎出现在训练数据的每个图像中,因此它被预测为经常出现并不奇怪。也许有些非常暗的单元格只基于少量图像?这将是有趣的内容需要进一步研究。

第10类:AUC = 0.09,非常糟糕。矩阵对角线上的0值令人担忧(您的数据标记正确吗?)。根据第一个矩阵的第10行,它似乎经常被误认为是3和9类别(棉花和主要切口刀看起来很像次要切口刀吗?)。可能还存在一些对训练数据过度拟合的情况。

第11类:AUC = 0.5,与随机一样。性能差(矩阵中得分明显过高)很可能是因为此标签存在于大多数训练图像中,但只存在于少数验证图像中。


还需绘制/测量什么?

为了更深入地了解您的数据,我建议首先绘制每个类别共同出现的频率热图(分别针对训练和验证数据)。单元格(i,j)将按包含标签i和j的图像比例着色。这将是一个对称的图,对角线上的单元格根据问题中的第一组数字进行着色。比较这两个热图,看看它们在哪些方面非常不同,并查看是否可以帮助解释您的模型性能。

此外,对于每个数据集,知道每个图像平均拥有多少个标签以及针对每个标签,它与平均多少其他标签共享图像可能也很有用。例如,我怀疑在训练数据中,具有第10类标签的图像相对较少。如果识别到其他物品,这可能会使网络不去预测第10类标签,并且如果第10类标签在验证数据中与其他对象分享图像,则会导致性能不佳。由于伪代码可能更容易传达观点,因此打印以下内容可能会很有趣:

# Do all of the following once for training data, AND once for validation data    
tot_num_labels = 0
for image in images:
    tot_num_labels += len(image.get_all_labels())
avg_labels_per_image = tot_num_labels / float(num_images)
print("Avg. num labels per image = ", avg_labels_per_image)

for label in range(num_labels):
    tot_shared_labels = 0
    for image in images_with_label(label):
        tot_shared_labels += (len(image.get_all_labels()) - 1)
    avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
    print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")

对于单个数据集而言,这并没有提供太多有用的信息,但是如果你对训练和验证集进行此操作,你可以看出它们的分布如何在数字差异很大的情况下存在很大的不同。

最后,我有点担心您第一个矩阵中的某些列恰好有完全相同的平均预测出现在许多不同的行上。我不太确定是什么原因导致了这种情况,但可能值得调查。


如何改进?

如果您还没有这样做,我建议您研究一下如何使用数据增强来处理训练数据。由于您正在处理图像,您可以尝试添加现有图像的旋转版本到您的数据中。

对于您的多标签情况,其中目标是检测不同类型的对象,尝试简单地将一堆不同的图像(例如两个或四个图像)连接在一起可能也很有趣。然后,您可以将它们缩小到原始图像大小,并将并集分配为原始标签集。 在合并图像的边缘处,您会得到有趣的不连续性,我不知道是否会对您的多对象检测造成伤害。在我看来,这值得一试。


谢谢您提供这么详细的答案。我编辑了我的问题并添加了更多细节,我只有几个评论:我确实为训练/验证集生成了热图,但它们没有帮助。您能否在“还要绘制什么”部分中更详细地解释一下您的第二个建议? - Maystro
我还有一个关于你在训练/验证集中每个对象出现频率之间进行相关性的问题(例如类别5和6)来给出你的解释。从我的角度来看,我只是检查了训练集中每个对象的出现频率,因为这是模型用来推进的。 - Maystro
@Maystro 对于第一个问题,我增加了更多信息。至于训练/验证集中标签的频率。假设,以极端的例子来说,在训练集中某个标签出现了100%(或0%)的时间。那么模型将不会学到任何东西,它只会预测100%或0%,而不管图像的样子如何,这在测试数据中可能是错误的。对于你来说不会那么极端,但当你的训练和验证集具有非常不同的分布时,你仍然可以观察到这样的影响。 - Dennis Soemers
我明白你的意思。我编辑了我的问题,包括你提供的代码输出。对我来说看起来还不错...不确定你对这样的结果有什么看法? - Maystro
@Maystro 对于一些类别(但不是全部),它可以帮助诊断性能差的问题。例如,在训练中,显然类别为10的图像平均包含其他7个对象。在验证中,这突然只有3个。也许您的网络没有学会识别类别为10的对象,也许它只是学会了识别“具有大量对象的图像”。总的来说,这表明您的训练和验证集具有显着不同的分布,这通常意味着您不能合理地期望机器学习有出色的表现。 - Dennis Soemers
显示剩余5条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接