使用model.fit_generator如何获取混淆矩阵

14
我正在使用model.fit_generator来训练并获取我的二元(两类)模型的结果,因为我直接从文件夹中提供输入图像。在这种情况下,如何获得混淆矩阵(TP、TN、FP、FN),因为通常我使用sklearn.metrics的confusion_matrix命令来获得它,该命令需要预测和实际标签。但是在这里,我没有这两个标签。也许我可以通过predict=model.predict_generator(validation_generator)命令计算预测标签。但我不知道我的模型是如何从我的图像中获取输入标签的。我的输入文件夹的一般结构为:
train/
 class1/
     img1.jpg
     img2.jpg
     ........
 class2/
     IMG1.jpg
     IMG2.jpg
test/
 class1/
     img1.jpg
     img2.jpg
     ........
 class2/
     IMG1.jpg
     IMG2.jpg
     ........

而我的一些代码块是:

train_generator = train_datagen.flow_from_directory('train',  
        target_size=(50, 50),  batch_size=batch_size,
        class_mode='binary',color_mode='grayscale')  


validation_generator = test_datagen.flow_from_directory('test',
        target_size=(50, 50),batch_size=batch_size,
        class_mode='binary',color_mode='grayscale')

model.fit_generator(
        train_generator,steps_per_epoch=250 ,epochs=40,
        validation_data=validation_generator,
        validation_steps=21 )

因此,上述代码自动输入了两个类别,但我不知道它将哪个视为类别0,哪个视为类别1。

3个回答

5

我是用keras.utils.Sequence来管理的,具体如下。

from sklearn.metrics import confusion_matrix
from keras.utils import Sequence


class MySequence(Sequence):
    def __init__(self, *args, **kwargs):
        # initialize
        # see manual on implementing methods

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # return index-th complete batch


# create data generator
data_gen = MySequence(evaluation_set, batch_size=10) 

n_batches = len(data_gen)

confusion_matrix(
    np.concatenate([np.argmax(data_gen[i][1], axis=1) for i in range(n_batches)]),    
    np.argmax(m.predict_generator(data_gen, steps=n_batches), axis=1) 
)

这个实现的类会以元组的形式返回数据批次,这样就不需要将它们全部保存在内存中。请注意,这必须在 __getitem__ 中实现,并且该方法必须为相同参数返回相同的批次。

不幸的是,这段代码迭代了两次数据:第一次,它从返回的批次创建真实答案的数组,第二次则调用模型的 predict 方法。


4
probabilities = model.predict_generator(generator=test_generator)

会给我们一组概率。
y_true = test_generator.classes

我们需要真实标签。

由于这是一个二元分类问题,您需要找到预测标签。为了做到这一点,您可以使用

y_pred = probabilities > 0.5

然后我们在测试数据集上有真实标签和预测标签。因此,混淆矩阵如下:

font = {
'family': 'Times New Roman',
'size': 12
}
matplotlib.rc('font', **font)
mat = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(conf_mat=mat, figsize=(8, 8), show_normed=False)

你是怎么确定这是一个二元分类的?抱歉,我对这一切都很新。 - Prany
我注意到在问题中。 - Shirantha Madusanka

3

您可以通过调用train_generatorvalidation_generator对象的属性class_indices来查看从类名到类索引的映射,例如:

train_generator.class_indices


你的帮助很有价值。当我输入 train_generator.class_indicesvalidation_generator.class_indices 时,它给出输出 {'class1': 0, 'class2': 1}。假设对于测试数据,我有两个类别各10个样本(共20个),那么我的输入类别将会是 [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]?我将从 predict=model.predict_generator(validation_generator) 中获取输出类别,然后就可以得到混淆矩阵了。 - Hitesh
1
@Hitesh,我有点困惑你的问题。如果你问的是class_indices的输出实际上如何映射到底层标签,请查看我的视频,在这里解释了这一点。输出在4:20左右显示。https://youtu.be/pZoy_j3YsQg 此外,您从predict_generator获得的输出将是您用于混淆矩阵的内容。 - blackHoleDetector
你能稍微详细解释一下如何从predict_generator转换到混淆矩阵吗? - bw4sz
一旦您使用predict_generator获得了预测结果,以及这些预测结果的标签,您可以使用sklearn的混淆矩阵绘制结果。我在这里演示了这个过程:https://youtu.be/km7pxKy4UHU 注意,在该视频中,我们使用predict()而不是predict_generator()进行预测,但概念是相同的。在观看完第一个视频后,如果您想查看predict_generator()的过程,请参考https://youtu.be/bfQBPNDy5EM,大约在5:30处。第二个视频没有像第一个视频那样详细地演示创建矩阵的整个过程。 - blackHoleDetector

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接