模型测试准确率为98%,但混淆矩阵不准确。

5

我训练了一个二元分类模型,测试准确率达到了98%,训练准确率达到了99%。

今天我想要计算混淆矩阵,并使用以下代码进行计算。

model = load_model("model.h5")

testGenerator = ImageDataGenerator(rotation_range=5,
                                width_shift_range=0.2,
                                height_shift_range=0.2,
                                horizontal_flip=False,
                                fill_mode='nearest'
                                )   

testData = testGenerator.flow_from_directory(
                                'Location', 
                                target_size=(74,448),                                                 
                                batch_size=15,
                                class_mode='binary',
                                shuffle=False
                                )

proba = model.predict_generator(testData,steps=3000//15)
y_true = np.array([0] * 1482 + [1] * 1482 )
y_pred = proba > 0.5
print(confusion_matrix(y_true, y_pred))

我收到了这个混淆矩阵:

Confusion Matrix

如sklearn所说:

enter image description here

它表明假阴性和假阳性都很高。因为我的测试准确率达到了98%,这怎么可能呢?我已经多次使用该模型生成预测(使用model.predict()函数)并手动检查过,每次都给出了正确的分类。

有什么办法可以获得准确的结果吗?


你的真实数据是否像你的y_true变量一样分布? - CupinaCoffee
1
@CupinaCoffee 我已经设置了 shuffle=false 来实现这个功能。 - Samitha Nanayakkara
好的。请在 https://github.com/keras-team/keras/issues/3477 中查看来自soumendra的评论。 - CupinaCoffee
@CupinaCoffee 谢谢。我之前看过那篇文章,但他使用了train_generator.class_indices,而我没有,因为我已经训练好了模型。 - Samitha Nanayakkara
听起来你的初始模型在训练过程中可能出现了过拟合。能否描述一下你训练模型的过程? - vielkind
@vealkind 我使用了“Earlystopping”方法来确保不会发生过拟合。此外,我的训练、验证和测试数据集从未有相同的数据。还使用了dropout层。 - Samitha Nanayakkara
1个回答

0

让我们从结尾开始。消息“TypeError: unhashable type: 'numpy.ndarray'”意味着您不能将numpy.ndarray用作字典键,因为它不是不可变对象。首先将其转换为tuple或其他不可变对象。

关于您的混淆矩阵,我打赌生成器以不可预测的顺序从文件夹中加载文件,但您已将y_true设置为1482个zeros和1482个ones--这可能与生成器产生的文件的顺序匹配,也可能不匹配。因此,您得到了有趣的结果。


我最近修复了那个错误。 但是,使用 predict_generator() 生成的预测仍然不如 predict() 函数准确。 另外,我已设置 shuffle=false。那么它怎么会以不可预测的方式考虑它们呢? - Samitha Nanayakkara
@Sam94,那就不要使用predict_generator()了吗? - lenik
生成混淆矩阵并进行统计分析,最好有大量样本,对吧?(至少500个)。但我只有100张图像。因此,我需要增强它们并进行预测。那么,还有哪些可能的解决方案可供我选择呢? - Samitha Nanayakkara
@Sam94 我并不反对使用生成器,只是你的标签不适用于数据集。在你手头的数据上计算混淆矩阵。如果结果显著改善,就修正你生成数据和基准真值的方式。 - lenik
我使用了我已有的图像而没有生成图像,得到了准确的结果。我的意思是使用predict()函数,而不是predict_generator()函数。因此,我需要纠正这个问题。这就是为什么我寻求一些帮助的原因。 - Samitha Nanayakkara
@Sam94在你的代码中不要使用生成器,单独创建一些图像并保存到磁盘上,确保y_true与数据匹配。 - lenik

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接