如何在sklearn.metrics包中更改plot_confusion_matrix默认图像大小

14

我尝试使用sklearn.metrics.plot_confusion_matrix包在Jupyter笔记本中绘制混淆矩阵,但默认的图像尺寸有点小。我在绘图之前添加了plt.figure(figsize=(20, 20)),但输出文本显示'Figure size 1440x1440 with 0 Axes',图像大小并未改变。我该如何更改图像大小?

%matplotlib inline
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import plot_confusion_matrix
from matplotlib import pyplot as plt

plt.figure(figsize=(20, 20))
clf = GradientBoostingClassifier(random_state=42)
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test, cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.show()

just like this image


12
fig, ax = plt.subplots(figsize=(20, 20)) 的意思是创建一个大小为 (20, 20) 的图形窗口和轴对象。然后,使用 plot_confusion_matrix(clf, X_test, y_test, cmap=plt.cm.Blues, ax=ax) 将分类器 clf 在测试数据集 X_test 和标签集 y_test 上的混淆矩阵绘制在这个轴上,并使用蓝色颜色映射。可参考文档 - BigBen
3个回答

25

我不知道为什么BigBen将那个作为评论而不是答案发布,但我差点错过了它。这里它已经被当作答案发表了,以便未来的旁观者不会犯我几乎犯的同样错误!

fig, ax = plt.subplots(figsize=(10, 10))
plot_confusion_matrix(your_model, X_test, y_test, ax=ax)

2
我使用 set_figwidthset_figheight 来指定图形的大小:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay  
import matplotlib.pyplot as plt 
disp = ConfusionMatrixDisplay.from_predictions(
                  [0,1,1,0,1], 
                  [0,1,0,1,0], 
                  labels=[1,0],
                  cmap=plt.cm.Blues,
                  display_labels=['Good','Bad'], 
                  values_format='',  
) 
fig = disp.ax_.get_figure() 
fig.set_figwidth(3)
fig.set_figheight(3)  

1

ConfusionMatrixDisplay 提供了比 plot_confusion_matrix 更多的控制和灵活性,用于可视化混淆矩阵。更多信息请参见:docs

from sklearn.metrics import ConfusionMatrixDisplay  
y_true = [0,1,1,0,1]
y_pred = [0,1,0,1,0]
labels = ['Good','Bad'] # 0: Good and 1: Bad
disp = ConfusionMatrixDisplay.from_predictions(
                                              y_true, 
                                              y_pred, 
                                              display_labels=labels, 
                                              cmap=plt.cm.Blues
                                              ) 
fig = disp.figure_
fig.set_figwidth(10)
fig.set_figheight(10) 
fig.suptitle('Plot of confusion matrix')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接