Seaborn 混淆矩阵(热力图)2 种颜色方案(正确对角线 vs 错误其他部分)

8

背景

在混淆矩阵中,对角线表示预测标签与正确标签匹配的情况。因此,对角线是好的,而所有其他单元格都是不好的。为了让非专业人士更加清楚混淆矩阵中的好与坏,我想给对角线和其他部分使用不同的颜色进行区分。我想使用 Python 和 Seaborn 来实现这一点。

基本上,我正在尝试实现 R 中这个问题所做的事情 (ggplot2 Heatmap 2 Different Color Schemes - Confusion Matrix: Matches in Different Color Scheme than Missclassifications)。

使用 heatmap 的普通 Seaborn 混淆矩阵

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

sns.heatmap(cf_matrix, annot=True, cmap='Blues')  # cmap='OrRd'

导致这个图片的结果是:

Seaborn Confusion Matrix with colormap 'Blues'

目标

我想使用例如cmap='OrRd'来着色非对角线单元格。因此,我想会有2个色条,一个蓝色代表对角线,一个代表其他单元格。最好两个色条的值相匹配(所以都是0-70而不是0-70和0-40)。 我该如何实现这个目标?

以下图片是使用照片编辑软件制作的,并非代码生成:

Desired Confusion Matrix color scheme

3个回答

15

你可以在调用heatmap()时使用mask=来选择要显示的单元格。使用两个不同的掩码分别用于对角线和非对角线单元格,即可获得所需的输出:

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

vmin = np.min(cf_matrix)
vmax = np.max(cf_matrix)
off_diag_mask = np.eye(*cf_matrix.shape, dtype=bool)

fig = plt.figure()
sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, cbar_kws=dict(ticks=[]))

这里输入图片描述

如果你想要更加精细,请使用GridSpec创建坐标轴以获得更好的布局:

import numpy as np import seaborn as sns

fig = plt.figure()
gs0 = matplotlib.gridspec.GridSpec(1,2, width_ratios=[20,2], hspace=0.05)
gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(1,2, subplot_spec=gs0[1], hspace=0)

ax = fig.add_subplot(gs0[0])
cax1 = fig.add_subplot(gs00[0])
cax2 = fig.add_subplot(gs00[1])

sns.heatmap(cf_matrix, annot=True, mask=~off_diag_mask, cmap='Blues', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax2)
sns.heatmap(cf_matrix, annot=True, mask=off_diag_mask, cmap='OrRd', vmin=vmin, vmax=vmax, ax=ax, cbar_ax=cax1, cbar_kws=dict(ticks=[]))

输入图像描述


可以通过注释颜色条来说明什么是好的,什么是坏的(在底部添加):cax2.set_title("X | O ", loc='right')。还可以添加一些额外的轴标签和标题:ax.set(xlabel='预测标签', ylabel='真实标签', title="混淆矩阵") - NumesSanguis

1

在一个相关问题中,有人问如何显示一个相关数据框,其中颜色条范围不包括对角线(即所有1的位置)。

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

iris = sns.load_dataset('iris').drop(columns=['species'])
corr_df = iris.corr()

plt.figure(figsize=(7,5))
ax = sns.heatmap(corr_df, annot=True, cmap='OrRd', mask=np.eye(len(corr_df)))
ax.set_yticklabels(ax.get_yticklabels(), va='center')
ax.patch.set_facecolor('skyblue')
ax.patch.set_edgecolor('white')
ax.patch.set_hatch('xx')

plt.tight_layout()
plt.show()

sns.heatmap for a correlation matrix without the diagonal


1
你可以先使用颜色映射“OrRd”绘制热力图,然后再叠加一个颜色映射为“Blues”的热力图,将上三角和下三角的值替换为NaN,参见以下示例:
def diagonal_heatmap(m):

    vmin = np.min(m)
    vmax = np.max(m)    
    
    sns.heatmap(cf_matrix, annot=True, cmap='OrRd', vmin=vmin, vmax=vmax)

    diag_nan = np.full_like(m, np.nan, dtype=float)
    np.fill_diagonal(diag_nan, np.diag(m))
    
    sns.heatmap(diag_nan, annot=True, cmap='Blues', vmin=vmin, vmax=vmax, cbar_kws={'ticks':[]}) 




cf_matrix = np.array([[50, 2, 38],
                      [7, 43, 32],
                      [9,  4, 76]])

diagonal_heatmap(cf_matrix)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接