在“散点/点/蜂群”图中避免数据点的重叠。

63
在使用matplotlib绘制点图时,我希望将重叠的数据点进行偏移,以保持它们全部可见。例如,如果我有:
CategoryA: 0,0,3,0,5  
CategoryB: 5,10,5,5,10  

我希望每个"0"数据点的CategoryA能够并排显示,而不是重叠在一起,同时仍然与CategoryB保持区分。
在R(ggplot2)中,有一个"jitter"选项可以实现这一点。在matplotlib中是否有类似的选项,或者是否有其他方法可以达到类似的效果?
编辑:为了澄清,我想要的实际上是R中的"beeswarm"图,而pybeeswarm是一个早期但有用的matplotlib/Python版本的起点。
编辑:补充说明,Seaborn的Swarmplot(在0.7版本中引入)是我想要的一个很好的实现。

1
点图中,这些点已经按列分开了。 - joaquin
1
“点图”(dot plot)的维基百科定义并不是我想要描述的,但除了“点图”之外,我从未听说过其他术语。它大致上是一个散点图,但具有任意(不一定是数字)的x标签。因此,在我在问题中描述的示例中,将会有一个“CategoryA”的值列,第二列为“CategoryB”,以此类推。(编辑:维基百科对“Cleveland点图”的定义更接近我所寻找的,但仍然不完全相同。) - iayork
类似问题:https://dev59.com/5LTma4cB1Zd3GeqP1yf4 - xApple
7个回答

63

在 @user2467675 的回答基础上进行扩展,以下是我是如何实现的:

def rand_jitter(arr):
    stdev = .01 * (max(arr) - min(arr))
    return arr + np.random.randn(len(arr)) * stdev

def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, hold=None, **kwargs):
    return scatter(rand_jitter(x), rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, **kwargs)

stdev 变量确保抖动足以在不同的比例尺上看到,但它假定轴的限制为零和最大值。

然后,您可以调用 jitter 而不是 scatter


我非常喜欢你自动计算抖动比例的功能。对我很有效。 - Chris Warth
如果 arr 只包含零(即 stdev=0),这个代码还能正常工作吗? - Dataman
1
我不得不从jitter()的参数和scatter()的调用中同时删除holdsverts,才能使其在2020年正常工作。希望这能帮助到某些人 :)。 - lx4r

22

Seaborn 通过 sns.swarmplot() 提供了类似直方图的分类点图,以及通过sns.stripplot() 提供了抖动的分类点图:

import seaborn as sns

sns.set(style='ticks', context='talk')
iris = sns.load_dataset('iris')

sns.swarmplot('species', 'sepal_length', data=iris)
sns.despine()

输入图像描述

sns.stripplot('species', 'sepal_length', data=iris, jitter=0.2)
sns.despine()

在此输入图片描述


你的示例不是两个分类变量,而是一个分类变量和一个数值变量(sepal_length)。 - felice
@felice 这个问题要求一个分类变量和一个数值变量。 - joelostblom
即使变量名包含单词“category”。但是我现在明白了我的困惑,谢谢。 - felice

15

我使用了numpy.random在每个类别的固定点周围将数据沿X轴进行“散点/蜂群”分布,然后基本上为每个类别执行pyplot.scatter():

import matplotlib.pyplot as plt
import numpy as np

#random data for category A, B, with B "taller"
yA, yB = np.random.randn(100), 5.0+np.random.randn(1000)

xA, xB = np.random.normal(1, 0.1, len(yA)), 
         np.random.normal(3, 0.1, len(yB))

plt.scatter(xA, yA)
plt.scatter(xB, yB)
plt.show()

散点图


8

一种解决问题的方法是将散点图/点图/蜜蜂图中的每一行视为直方图中的一个条形箱:

data = np.random.randn(100)

width = 0.8     # the maximum width of each 'row' in the scatter plot
xpos = 0        # the centre position of the scatter plot in x

counts, edges = np.histogram(data, bins=20)

centres = (edges[:-1] + edges[1:]) / 2.
yvals = centres.repeat(counts)

max_offset = width / counts.max()
offsets = np.hstack((np.arange(cc) - 0.5 * (cc - 1)) for cc in counts)
xvals = xpos + (offsets * max_offset)

fig, ax = plt.subplots(1, 1)
ax.scatter(xvals, yvals, s=30, c='b')

这显然涉及到数据分组,因此可能会失去一些精度。如果您有离散数据,您可以替换:

counts, edges = np.histogram(data, bins=20)
centres = (edges[:-1] + edges[1:]) / 2.

使用:

centres, counts = np.unique(data, return_counts=True)

一种保留连续数据精确y坐标的替代方法是使用核密度估计来缩放x轴上的随机抖动振幅:

from scipy.stats import gaussian_kde

kde = gaussian_kde(data)
density = kde(data)     # estimate the local density at each datapoint

# generate some random jitter between 0 and 1
jitter = np.random.rand(*data.shape) - 0.5 

# scale the jitter by the KDE estimate and add it to the centre x-coordinate
xvals = 1 + (density * jitter * width * 2)

ax.scatter(xvals, data, s=30, c='g')
for sp in ['top', 'bottom', 'right']:
    ax.spines[sp].set_visible(False)
ax.tick_params(top=False, bottom=False, right=False)

ax.set_xticks([0, 1])
ax.set_xticklabels(['Histogram', 'KDE'], fontsize='x-large')
fig.tight_layout()

这种第二种方法基于小提琴图的工作原理。它仍然不能保证没有任何点重叠,但我发现实际应用中只要有足够数量的点(>20),并且分布可以被合理地近似为高斯函数之和,通常会得到相当漂亮的结果。

enter image description here


不幸的是,在 xvals = 1 + (density * jitter * width * 2) 部分中的 2 是一个参数,必须根据数据集进行调整。对于我的数据,我不得不将其设置为2000才能看到任何抖动,并将其设置为20,000以在最密集的区域获得良好的离散度。 - Aaron Bramson

7

如果您不知道这里是否有直接的mpl替代方案,那么我给出一个非常基本的建议:

from matplotlib import pyplot as plt
from itertools import groupby

CA = [0,4,0,3,0,5]  
CB = [0,0,4,4,2,2,2,2,3,0,5]  

x = []
y = []
for indx, klass in enumerate([CA, CB]):
    klass = groupby(sorted(klass))
    for item, objt in klass:
        objt = list(objt)
        points = len(objt)
        pos = 1 + indx + (1 - points) / 50.
        for item in objt:
            x.append(pos)
            y.append(item)
            pos += 0.04

plt.plot(x, y, 'o')
plt.xlim((0,3))

plt.show()

enter image description here


6

Seaborn的swarmplot看起来是最符合您设想的,但您也可以使用Seaborn的regplot进行抖动:

import seaborn as sns
iris = sns.load_dataset('iris')

sns.swarmplot('species', 'sepal_length', data=iris)

sns.regplot(x='sepal_length',
            y='sepal_width',
            data=iris,
            fit_reg=False,  # do not fit a regression line
            x_jitter=0.1,  # could also dynamically set this with range of data
            y_jitter=0.1,
            scatter_kws={'alpha': 0.5})  # set transparency to 50%

5

扩展@wordsforthewise的答案(抱歉,由于声誉问题无法发表评论),如果您需要同时使用抖动和色调将点按某些类别着色(例如我所做的),则Seaborn的lmplot是一个很好的选择,而不是reglpot:

import seaborn as sns
iris = sns.load_dataset('iris')
sns.lmplot(x='sepal_length', y='sepal_width', hue='species', data=iris, fit_reg=False, x_jitter=0.1, y_jitter=0.1)  

如果您想在现有答案中添加内容,可以进行编辑。如果该答案已经足够好了,您也可以添加另一个答案来扩展它。 ;) - Mohammad Reza Shahrestani
x_jitter 最近已被弃用。 - skjerns

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接