解决混淆矩阵绘图中的线条问题

6

我试图绘制如下所示的混淆矩阵

cm  = confusion_matrix(testY.argmax(axis=1), predictions.argmax(axis=1))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=lb.classes_)
disp = disp.plot(include_values=True, cmap='viridis', ax=None, xticks_rotation='horizontal')

plt.show()

结果:

我得到的混淆矩阵

如您所见,它显示的是方框的轴线而不是轮廓。由于轴线的存在,我看不到黄色方框外面的数字。我对绘图不熟悉,因此无法找出需要更改的内容。

我期望的是: 期望的混淆矩阵

找到解决方案

plt.tick_params(axis=u'both', which=u'both',length=0)
plt.grid(b=None)

您可以使用plot_confusion_matrix - Ahmet
尝试使用sns绘图,它简单易用且基于matplotlib编写。 - Suryaveer Singh
@Ahx plot_confusion_matrix 将被弃用,建议使用 ConfusionMatrixDisplay。 - erp_da
5个回答

5

关闭网格

例如:

import matplotlib.pyplot as plt
fig, _ = plt.subplots(nrows=1, figsize=(10,10))
ax = plt.subplot(1, 1, 1)
ax.grid(False)

...

disp = ConfusionMatrixDisplay(...)
_ = disp.plot(..., ax=ax, ...)

3
cm  = confusion_matrix(testY.argmax(axis=1), predictions.argmax(axis=1))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=lb.classes_)
disp = disp.plot(include_values=True, cmap='viridis', ax=None, xticks_rotation='horizontal')
plt.grid(False)
plt.show()

1

我在最早的一些单元格中使用了plt.rcParams['axes.grid'] = True(用于绘制另外的matplotlib图表)。因此,在ConfusionMatrixDisplay之前,我将它关闭了。

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
plt.rcParams['axes.grid'] = True

...

plt.rcParams['axes.grid'] = False
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))

disp_rfc.plot(ax = ax[0], cmap='coolwarm')
disp_cbc.plot(ax = ax[1], cmap='coolwarm')

plt.show()

0

plot()函数中修改您的cmap参数。它代表着将整数值与颜色进行颜色映射。

请检查

https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html

更多细节请参考。

作为答案。

cm  = confusion_matrix(testY.argmax(axis=1), predictions.argmax(axis=1))

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=lb.classes_)
disp = disp.plot(include_values=True, cmap='Blues', ax=None, xticks_rotation='horizontal')

plt.show()

这将解决颜色问题,但不是他实际遇到的问题,你没有理解实际问题。 - Suryaveer Singh
他在处理颜色和坐标轴方面遇到了问题,我提供了解决颜色问题的方法,因为他已经知道如何修复坐标轴。 - Harut Hunanyan

0
你展示的例子图表是通过sns plot绘制的。你可以使用sns heatmap来绘制你的矩阵。
import seaborn as sns
categories = lb.classes_
sns.heatmap(cm, annot=True,categories =categories, cmap='Blues')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接