使用pandas绘制相关矩阵

361

我有一个由大量特征组成的数据集,因此分析相关矩阵变得非常困难。我想绘制一个相关矩阵,使用 pandas 库中的 dataframe.corr() 函数获得。pandas 库是否提供任何内置函数来绘制此矩阵?


1
相关的答案可以在这里找到 从pandas DataFrame中制作热图 - joelostblom
Seaborn的clustermap可能也是可视化相关矩阵的有趣方式:sns_plot = sns.clustermap(dataframe.corr(), cmap="rocket_r") - nim.py
19个回答

457
你可以使用matplotlib中的pyplot.matshow()函数:pyplot.matshow()
import matplotlib.pyplot as plt

plt.matshow(dataframe.corr())
plt.show()

编辑:

在评论中有人要求如何更改轴刻度标签。这是高级版本,绘制在更大的图形尺寸上,具有与数据框匹配的轴标签以及一个带有颜色条例的色标解释。

我包括如何调整标签的大小和旋转,并使用一个使色条和主图高度相同的图形比率。


编辑2: 由于df.corr()方法忽略非数字列,因此在定义x和y标签时应该使用.select_dtypes(['number'])以避免标签的不必要偏移(包含在下面的代码中)。

f = plt.figure(figsize=(19, 15))
plt.matshow(df.corr(), fignum=f.number)
plt.xticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14, rotation=45)
plt.yticks(range(df.select_dtypes(['number']).shape[1]), df.select_dtypes(['number']).columns, fontsize=14)
cb = plt.colorbar()
cb.ax.tick_params(labelsize=14)
plt.title('Correlation Matrix', fontsize=16);

相关性图示例


1
@TomRussell 你有做 import matplotlib.pyplot as plt 吗? - joelostblom
13
你知道如何在图表上显示实际的列名吗? - WebQube
2
@Cecilia 我已经通过将“rotation”参数更改为“90”来解决了这个问题。 - Ikbel
1
@ikbelbenabdessamad 非常感谢! - Cecilia
4
如果列名比那些更长,x轴标签看起来会有点不对劲,在我的情况下,它们看起来偏移了一个刻度,这让我感到困惑。在plt.xticks调用中添加ha="left"可以解决这个问题,如果其他人也有这个问题:) 参见 https://dev59.com/GF4b5IYBdhLWcg3w1U3L - V. Déhaye
显示剩余4条评论

423

如果您的主要目标是可视化相关矩阵,而不是创建一个图表本身,方便的pandas 样式选项是一个可行的内置解决方案:

import pandas as pd
import numpy as np

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
corr = df.corr()
corr.style.background_gradient(cmap='coolwarm')
# 'RdBu_r', 'BrBG_r', & PuOr_r are other good diverging colormaps

enter image description here

请注意,这需要在支持渲染HTML的后端中,比如JupyterLab笔记本中。


样式

您可以轻松限制数字精度(现在在pandas 2.* 中使用.format(precision=2)):

corr.style.background_gradient(cmap='coolwarm').set_precision(2)

enter image description here

如果你喜欢没有注释的矩阵,也可以完全去掉数字。
corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})

enter image description here

样式文档还包括更高级的样式指令,比如如何改变鼠标悬停在单元格上时的显示方式。
时间比较
在我的测试中,style.background_gradient()plt.matshow() 快4倍,比 sns.heatmap() 快120倍,对于一个10x10的矩阵。不幸的是,它的扩展性不如 plt.matshow():对于一个100x100的矩阵,两者花费的时间大致相同,而对于一个1000x1000的矩阵,plt.matshow() 快10倍。

保存

有几种可能的方法来保存样式化的数据框:

  • 通过追加render()方法并将输出写入文件,返回HTML。
  • 通过追加to_excel()方法以条件格式保存为.xslx文件。
  • 与imgkit结合使用以保存位图
  • 截屏(就像我在这里所做的那样)。

规范化整个矩阵的颜色(pandas >= 0.24)

通过设置axis=None,现在可以根据整个矩阵而不是每列或每行来计算颜色:

corr.style.background_gradient(cmap='coolwarm', axis=None)

enter image description here


单个角热力图

由于很多人正在阅读这个答案,我想我可以提供一个小技巧,来仅显示相关矩阵的一个角落。我发现这样更容易阅读,因为它去除了冗余信息。

# Fill diagonal and upper half with NaNs
mask = np.zeros_like(corr, dtype=bool)
mask[np.triu_indices_from(mask)] = True
corr[mask] = np.nan
(corr
 .style
 .background_gradient(cmap='coolwarm', axis=None, vmin=-1, vmax=1)
 .highlight_null(color='#f1f1f1')  # Color NaNs grey
 .format(precision=2))

enter image description here


3
如果有一种方法可以将其导出为图像,那就太好了! - Kristada673
2
谢谢!你肯定需要一个分散的调色板corr = df.corr() cm = sns.light_palette("green", as_cmap=True) cm = sns.diverging_palette(220, 20, sep=20, as_cmap=True) corr.style.background_gradient(cmap=cm).set_precision(2)``` - stallingOne
2
这些图形视觉效果很棒,但@Kristada673的问题非常相关,你会如何导出它们? - Erfan
1
@roudan 你可以使用 from IPython.display import display (或导入 display_html)然后在循环中使用 display(df)。https://ipython.readthedocs.io/en/stable/api/generated/IPython.display.html#IPython.display.display - joelostblom
1
这很棒,你还可以手动设置颜色限制,而不是使用数据范围,例如vmin=-1,vmax=1 - SpinUp __ A Davis
显示剩余18条评论

125

Seaborn的热力图版本:

import seaborn as sns
corr = dataframe.corr()
sns.heatmap(corr, 
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values)

17
Seaborn 热力图效果华丽,但在处理大型矩阵时性能较差。Matplotlib 的 matshow 方法更快。 - anilbey
7
Seaborn可以自动从列名推断刻度标签。 - Tulio Casagrande
1
如果让seaborn自动推断,似乎并不总是显示所有的刻度标签。 - janto
最好也包括将颜色归一化为-1到1,否则颜色将跨越从最低的相关性(可以是任何地方)到最高的相关性(对角线上的1)。 - Nuclear03020704

105
你可以使用seaborn绘制热力图或使用pandas绘制散点矩阵来观察特征之间的关系。
散点矩阵:
pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');

如果您想将每个特征的偏度可视化,可以使用seaborn pairplots。
sns.pairplot(dataframe)

Sns热力图:

import seaborn as sns

f, ax = pl.subplots(figsize=(10, 8))
corr = dataframe.corr()
sns.heatmap(corr,
    cmap=sns.diverging_palette(220, 10, as_cmap=True),
    vmin=-1.0, vmax=1.0,
    square=True, ax=ax)

输出将是特征的相关性图。例如,请参见下面的示例。

enter image description here

杂货和洗涤剂之间的相关性很高。同样:

具有高相关性的产品:

  1. 杂货和洗涤剂。

具有中等相关性的产品:

  1. 牛奶和杂货
  2. 牛奶和清洁用品

具有低相关性的产品:

  1. 牛奶和熟食
  2. 冷冻和新鲜食品。
  3. 冷冻和熟食。

从配对图中可以观察到相同的关系集合或散点矩阵。但是,我们可以从这些数据中得出数据是否符合正态分布。

enter image description here

注意:以上是从数据中获取的相同图形,用于绘制热力图。

3
如果这是指matplotlib,我认为应该是.plt而不是.pl。 - ghukill
3
不一定。他可以将其称为 from matplotlib import pyplot as pl - Jeru Luke
如何在相关图中始终设置相关性界限为-1到+1? - Debashis Sahoo
1
很好的答案。如果有人遇到错误 AttributeError: module 'pandas' has no attribute 'scatter_matrix',请参考此问题以获取帮助。简而言之:使用 pd.plotting.scatter_matrix() - DaytaSigntist
应该使用pd.plotting.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');而不是pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde'); - undefined

104

尝试使用此函数,该函数还会显示相关矩阵的变量名称:

def plot_corr(df,size=10):
    """Function plots a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot
    """

    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)
    plt.xticks(range(len(corr.columns)), corr.columns)
    plt.yticks(range(len(corr.columns)), corr.columns)

9
plt.xticks(range(len(corr.columns)), corr.columns, rotation='vertical') 如果您想要在x轴上以垂直方向显示列名。 - nishant
另一个图形相关的事情,但是添加 plt.tight_layout() 对于较长的列名也可能很有用。 - user3017048

18

为了完整起见,截至2019年末,我所知道的最简单的解决方案是使用seaborn,如果使用Jupyter

import seaborn as sns
sns.heatmap(dataframe.corr())

13

惊讶地发现没有人提到更有能力、互动性更强、更易于使用的替代方案。

A)您可以使用plotly:

  1. 只需两行代码即可获得:

  2. 交互性,

  3. 平滑缩放,

  4. 基于整个数据帧而不是单独列的颜色,

  5. 轴上的列名和行索引,

  6. 缩放,

  7. 平移,

  8. 内置的一键式保存为PNG格式的功能,

  9. 自动缩放,

  10. 悬停比较,

  11. 显示值的气泡,因此热图看起来仍然很好,并且您可以在任何地方查看值:

import plotly.express as px
fig = px.imshow(df.corr())
fig.show()

输入图像描述

B) 您也可以使用Bokeh:

所有相同的功能,但稍微有些麻烦。但如果您不想选择plotly仍然想要所有这些功能,那么它仍然是值得的。

from bokeh.plotting import figure, show, output_notebook
from bokeh.models import ColumnDataSource, LinearColorMapper
from bokeh.transform import transform
output_notebook()
colors = ['#d7191c', '#fdae61', '#ffffbf', '#a6d96a', '#1a9641']
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
data = df.corr().stack().rename("value").reset_index()
p = figure(x_range=list(df.columns), y_range=list(df.index), tools=TOOLS, toolbar_location='below',
           tooltips=[('Row, Column', '@level_0 x @level_1'), ('value', '@value')], height = 500, width = 500)

p.rect(x="level_1", y="level_0", width=1, height=1,
       source=data,
       fill_color={'field': 'value', 'transform': LinearColorMapper(palette=colors, low=data.value.min(), high=data.value.max())},
       line_color=None)
color_bar = ColorBar(color_mapper=LinearColorMapper(palette=colors, low=data.value.min(), high=data.value.max()), major_label_text_font_size="7px",
                     ticker=BasicTicker(desired_num_ticks=len(colors)),
                     formatter=PrintfTickFormatter(format="%f"),
                     label_standoff=6, border_line_color=None, location=(0, 0))
p.add_layout(color_bar, 'right')

show(p)

输入图像描述这里


11

如果你的数据框叫做df,你可以简单地使用:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(15, 10))
sns.heatmap(df.corr(), annot=True)

10

您可以使用Matplotlib的imshow()方法

import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
plt.colorbar()
tick_marks = [i for i in range(len(X.columns))]
plt.xticks(tick_marks, X.columns, rotation='vertical')
plt.yticks(tick_marks, X.columns)
plt.show()

10
我认为有很多好的答案,但我添加这个答案是为了那些需要处理特定列并展示不同图表的人。
import numpy as np
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(18, 18))
df= df.iloc[: , [3,4,5,6,7,8,9,10,11,12,13,14,17]].copy()
corr = df.corr()
plt.figure(figsize=(11,8))
sns.heatmap(corr, cmap="Greens",annot=True)
plt.show()

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接