如何消除混淆矩阵中的白线?

10
有人知道为什么这些白线将我的混淆矩阵分成四个部分吗?我已经更改了许多参数,但是无法弄清楚。唯一可以消除它们的方法是不完全标记块,即“0”,“1”等,但这显然不是我想要的结果。如果能有所帮助,将不胜感激。
代码:
def plot_confusion_matrix(cm,
                          target_names = ['1', '2', '3', '4'],
                          title = 'Confusion matrix',
                          cmap = None,
                          normalize = False):
    """
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citiation
    ---------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize = (8, 6))
    plt.imshow(cm, interpolation = 'nearest', cmap = cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation = 0)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis = 1)[:, np.newaxis]


    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment = "center",
                     color = "white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment = "center",
                     color = "white" if cm[i, j] > thresh else "black")


    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()


plot_confusion_matrix(cm           = (confusion), 
                      normalize    = True,
                      target_names = ['1', '2', '3', '4'],
                      title        = "Confusion Matrix")

输出结果为:

在此输入图片描述


5
我想你只需要禁用网格。 - R2RT
对不起,但是你的意思是什么? - Chris Macaluso
3
我的意思是这些白色的线是由 pyplot 默认绘制的网格线,你需要将其禁用。我不知道具体的函数,但应该很容易找到一个。 - R2RT
5
谢谢!这让我朝着正确的方向前进了,我不知道它叫什么。需要一个 plt.grid(None) - Chris Macaluso
2个回答

15
plt.figure(figsize=(10,5))

plt.grid(False)

plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=False, title='Normalized confusion matrix')

1
如果您想格式化代码,请在三个反引号之间放置代码。如果您查看,会有一个链接向您展示如何在帖子中使用Markdown。请将您的代码放置在这里 - ThePyGuy
plt.grid(False)是我需要的,谢谢! - undefined

1
尝试重置它。
import seaborn as sns 
# Try resetting it 
sns.reset_orig() 

目前你的回答不够清晰,请[编辑]以添加更多细节,以帮助他人了解它如何解决所提出的问题。您可以在帮助中心找到有关编写良好答案的更多信息。 - Community

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接