Matplotlib matshow带有许多字符串标签

11
今天我尝试从我的分类模型绘制混淆矩阵。在查看了一些页面后,我发现pyplot中的matshow可以帮助我。
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        ax.set_xticklabels([''] + labels)
        ax.set_yticklabels([''] + labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

如果我只有很少的标签,它就能很好地工作

y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a']
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)

enter image description here

但是如果我有很多标签,有些标签就无法正确显示。

y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
labels = list(set(y_true))
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, labels=labels)

enter image description here

我的问题是如何在matshow图中显示所有标签?我尝试了一些类似于fontdict的东西,但它仍然无法正常工作。
2个回答

7
您可以使用matplotlib.ticker模块来控制刻度的频率。

在这种情况下,您希望设置每个1的倍数的刻度,因此我们可以使用MultipleLocator

在调用plt.show()之前添加以下两行:

ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

这将会为你的y_truey_pred中的每个字母生成一个勾号和标签。

我还修改了你的matshow调用,以使用你在函数调用中指定的颜色映射。

cax = ax.matshow(cm,cmap=cmap)

在此输入图片描述

为了完整性,你的整个函数将如下所示:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as ticker

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call
    cax = ax.matshow(cm,cmap=cmap)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        ax.set_xticklabels([''] + labels)
        ax.set_yticklabels([''] + labels)

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('confusionmatrix.png')

它成功了。谢谢Tom。另一个问题,你知道如何在matshow图中绘制混淆矩阵的值吗? - Vu Anh
最好提出一个新问题,更详细地说明您想要什么。 - tmdavison
当我在 ax.xaxis.set_major_locator 后面加上 ax.matshow(data, cmap=plt.cm.Blues, alpha=0.7) 时,它没有起作用;这只是给其他人的一个提示。 - Alex Punnen

6
你可以使用 xticks 方法来指定标签。你的函数将会是这样的(对以上答案中的函数进行修改):
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call
    cax = ax.matshow(cm,cmap=cmap)
    plt.title(title)
    fig.colorbar(cax)
    if labels:
        plt.xticks(range(len(labels)), labels)
        plt.yticks(range(len(labels)), labels)

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('confusionmatrix.png')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接