类数为4,与目标名称数6不匹配。尝试指定标签参数。

8

当我尝试制作卷积神经网络模型的混淆矩阵时,遇到了一些问题。当我运行代码时,会返回一些错误,例如:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

Traceback (most recent call last):

  File "<ipython-input-102-82d46efe536a>", line 1, in <module>
    print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

  File "G:\anaconda_installation_file\lib\site-packages\sklearn\metrics\classification.py", line 1543, in classification_report
    "parameter".format(len(labels), len(target_names))

ValueError: Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

我已经搜索了如何解决这个问题,但仍然没有找到完美的解决方案。 我完全是新手,能否有人帮帮我? 谢谢。

from sklearn.metrics import classification_report,confusion_matrix
import itertools

Y_pred = model.predict(X_test)
print(Y_pred)
y_pred = np.argmax(Y_pred, axis=1)
print(y_pred)

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)']

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

欢迎来到SO,我们不是通过抛出所有代码来解决问题;错误后面的代码是多余的,与问题无关(它永远不会被执行),只会增加不必要的混乱(已删除)。请参阅如何创建一个最小、完整和可验证的示例 - desertnaut
3个回答

2
问题在于你有6个标签名:'class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)',但是你的混淆矩阵中只有4个类别。当你打印y_pred时,你会得到一些数字,其中包含0,1,2,3,或者当你打印y_test时,你会得到来自0,1,2,3的数字,应该删除以下内容以帮助解决问题: print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names)) 从你的代码中,某种程度上你没有6个预测/测试类别。
这里还有一个绘制混淆矩阵的示例: 如何绘制混淆矩阵?

2
你的问题表述应该更清晰!我会做出一些假设!
问题在于:
目标名称为 `['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)', 'class 3(paper)', 'class 4(plastic)', 'class 5(trash)']` 共6个类别,而你的模型只能预测4个类别,导致混淆矩阵中只有4个类别(应该是6x6而不是6x4)。要纠正这个问题,只需提供类别标签。例如,如果预测变量中有3个标签,分别为1、2、3,则可以使用以下代码:
```python print(classification_report(y_true, y_pred, labels=[1, 2, 3])) ```
请参阅此处的文档。
PS: 1. 你的模型表现很差。 2. 你的数据集可能存在类别不平衡的问题。

1

只需添加 "labels=np.arange(0,len(class_names),1)" 就可以解决你的问题,例如:

classification_report(y_true, y_pred,labels=np.arange(0,len(class_names),1),target_names=class_names, digits=4,zero_division=0)


这真的有效!尝试指定标签参数! - Hong Cheng

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接