多重指标绘图

20

我有一些数据,我已经使用以下代码操作了数据框:

import pandas as pd
import numpy as np

data = pd.DataFrame([[0,0,0,3,6,5,6,1],[1,1,1,3,4,5,2,0],[2,1,0,3,6,5,6,1],[3,0,0,2,9,4,2,1],[4,0,1,3,4,8,1,1],[5,1,1,3,3,5,9,1],[6,1,0,3,3,5,6,1],[7,0,1,3,4,8,9,1]], columns=["id", "sex", "split", "group0Low", "group0High", "group1Low", "group1High", "trim"])
data

#remove all where trim == 0
trimmed = data[(data.trim == 1)]
trimmed

#create df with columns to be split
columns = ['group0Low', 'group0High', 'group1Low', 'group1High']
to_split = trimmed[columns]
to_split

level_group = np.where(to_split.columns.str.contains('0'), 0, 1)
# output: array([0, 0, 1, 1])
level_low_high = np.where(to_split.columns.str.contains('Low'), 'low', 'high')
# output: array(['low', 'high', 'low', 'high'], dtype='<U4')

multi_level_columns = pd.MultiIndex.from_arrays([level_group, level_low_high], names=['group', 'val'])
to_split.columns = multi_level_columns
to_split.stack(level='group')

sex = trimmed['sex']
split = trimmed['split']
horizontalStack = pd.concat([sex, split, to_split], axis=1)
horizontalStack

finalData = horizontalStack.groupby(['split', 'sex', 'group'])
finalData.mean()

我的问题是,如何使用ggplot或seaborn绘制平均数据,以便对于每个“split”级别,我都可以获得类似于此图的图形:

enter image description here

在代码底部,您可以看到我已经尝试分割组因子,以便分离条形图,但结果出现错误(KeyError:“group”),我认为这与我使用的多重索引方式有关


2
你能把代码和数据复制到你的问题中吗? - maxymoo
2个回答

38
我会使用seaborn的因子图(factor plot)。
假设你有这样的数据:
import numpy as np
import pandas

import seaborn
seaborn.set(style='ticks') 
np.random.seed(0)

groups = ('Group 1', 'Group 2')
sexes = ('Male', 'Female')
means = ('Low', 'High')
index = pandas.MultiIndex.from_product(
    [groups, sexes, means], 
   names=['Group', 'Sex', 'Mean']
)

values = np.random.randint(low=20, high=100, size=len(index))
data = pandas.DataFrame(data={'val': values}, index=index).reset_index()
print(data)

     Group     Sex  Mean  val
0  Group 1    Male   Low   64
1  Group 1    Male  High   67
2  Group 1  Female   Low   84
3  Group 1  Female  High   87
4  Group 2    Male   Low   87
5  Group 2    Male  High   29
6  Group 2  Female   Low   41
7  Group 2  Female  High   56

你可以使用一个命令来创建因子图,再加上一行额外的代码来删除一些多余的(对于你的数据而言)x轴标签。
# Note: catplot used to be called factorplot
fg = seaborn.catplot(x='Group', y='val', hue='Mean', 
                        col='Sex', data=data, kind='bar')
fg.set_xlabels('')

给我带来的是:

enter image description here


这很完美,谢谢!有没有一种方法可以绘制误差条,其中所表示的误差是平均标准误差? - Simon
@Nem 我现在无法查看任何范围蔓延。但这回答了你最初的问题。至于后续,这个 SO 问题是我在谷歌搜索“seaborn error bars”时得到的第一个结果。https://dev59.com/iGAf5IYBdhLWcg3wOQm2 - Paul H
1
哇,仔细阅读你的代码让我学到了很多关于多索引和绘图的知识,这些以前一直困扰着我。真的非常棒,因为它的简洁易懂! - Mad Physicist
4
关键在于reindex,它会移除多级索引使得(原来的)索引被当作列进行处理。 - user2699
1
请注意,自版本0.9(2018年7月)以来,factorplot已更名为catplot - Michaël
显示剩余3条评论

18

在一个相关问题中,我发现了Stein提供的一种替代方案,它将多级索引编码为不同的标签。以下是您的示例的实现方式:

import pandas as pd
import matplotlib.pyplot as plt
from itertools import groupby
import numpy as np 
%matplotlib inline

groups = ('Group 1', 'Group 2')
sexes = ('Male', 'Female')
means = ('Low', 'High')
index = pd.MultiIndex.from_product(
    [groups, sexes, means], 
   names=['Group', 'Sex', 'Mean']
)

values = np.random.randint(low=20, high=100, size=len(index))
data = pd.DataFrame(data={'val': values}, index=index)
# unstack last level to plot two separate columns
data = data.unstack(level=-1)

def add_line(ax, xpos, ypos):
    line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                      transform=ax.transAxes, color='gray')
    line.set_clip_on(False)
    ax.add_line(line)

def label_len(my_index,level):
    labels = my_index.get_level_values(level)
    return [(k, sum(1 for i in g)) for k,g in groupby(labels)]

def label_group_bar_table(ax, df):
    ypos = -.1
    scale = 1./df.index.size
    for level in range(df.index.nlevels)[::-1]:
        pos = 0
        for label, rpos in label_len(df.index,level):
            lxpos = (pos + .5 * rpos)*scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos*scale, ypos)
            pos += rpos
        add_line(ax, pos*scale , ypos)
        ypos -= .1

ax = data['val'].plot(kind='bar')
#Below 2 lines remove default labels
ax.set_xticklabels('')
ax.set_xlabel('')
label_group_bar_table(ax, data)

这给出了:

在此输入图片描述


2
plt.Line2D 中,我建议添加 linewidth=0.8color=black 以更好地将线条与绘图框架整合。 - Patrick FitzGerald

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接