如何按类别创建散点图

114
我正在尝试使用Pandas DataFrame对象在pyplot中制作一个简单的散点图,但希望找到一种有效的方法来绘制两个变量,但符号由第三列(key)指定。 我已经尝试过使用df.groupby的各种方法,但都没有成功。下面是一个样本df脚本。这将根据“key1”为标记着色,但我想看到一个带有“key1”类别的图例。我接近了吗?谢谢。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
8个回答

141

你可以使用scatter进行此操作,但这需要您的key1具有数字值,并且您将没有图例,正如您所注意到的那样。

对于离散类别,最好只使用plot。 例如:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

enter image description here

如果你想让东西看起来像是默认的pandas样式,那么只需使用pandas样式表更新rcParams并使用其颜色生成器即可。 (我还稍微调整了图例):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

这里输入图片描述


在上面的RGB示例中,为什么符号在图例中显示了两次?如何只显示一次? - Steve Schulist
1
使用ax.legend(numpoints=1)来仅显示一个标记。通常使用Line2D时有两个标记,并且它们之间经常连接一条线。 - Joe Kington
ax.plot()命令之后添加plt.hold(True)后,这段代码才对我起作用。 有任何想法为什么会这样吗? - Yuval Atzmon
set_color_cycle()在matplotlib 1.5中已被弃用。现在有set_prop_cycle()可用。 - a.l.e
非常反直觉,但谢谢! - Stepan S. Sushko

69

通过使用Seabornpip install seaborn)作为一行代码,这很容易实现。

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1")

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

enter image description here

这是参考数据框:

enter image description here

由于您的数据中有三个变量列,您可能希望使用以下方法绘制所有配对维数:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

输入图像描述

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/是另一个选择。


19

使用plt.scatter,我只能想到一种方法:使用代理艺术家:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

结果是:

在此输入图片描述


13
你可以使用df.plot.scatter,并将数组传递给c=参数来定义每个点的颜色。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

输入图像描述


9
从 matplotlib 3.1 开始,您可以使用 .legend_elements()。示例显示在 自动创建图例 中。优点是可以使用单个散点调用。
在这种情况下:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

enter image description here

如果键不是直接以数字给出的话,它看起来会像这样
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

enter image description here


我得到了一个错误,说“PathCollection”对象没有属性“legends_elements”。我的代码如下。fig,ax = plt.subplots(1,1,figsize =(4,4)) scat = ax.scatter(rand_jitter(important_dataframe [“workout_type_int”],jitter = 0.04),                 important_dataframe [“distance”],c = color_list,marker ='o',alpha = 0.9) print(scat.legends_elements()) #ax.legend(* scat.legend_elements()) - Nandish Patel
1
@NandishPatel,请检查此答案的第一句话。同时确保不要混淆legends_elementslegend_elements - ImportanceOfBeingErnest
是的,谢谢。那是一个笔误(legends/legend)。我已经在做某件事情了六个小时,所以Matplotlib版本没有出现在我的脑海中。我以为我正在使用最新的版本。我很困惑,因为文档上说有这样的方法,但代码却报错了。再次感谢您。我现在可以睡觉了。 - Nandish Patel

5

您还可以尝试使用专注于声明式可视化的 Altairggplot

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Altair 代码

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

这里插入图片描述

ggplot 代码

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

enter image description here


3

虽然有些hacky,但您可以使用one1作为Float64Index来一次性完成所有操作:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

enter image description here

请注意,从0.20.3版本开始,需要对索引进行排序,而图例显示可能有些奇怪

1

Seaborn有一个包装函数scatterplot,可以更高效地完成此操作。

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接