Matplotlib散点图图例

96

我创建了一个四维散点图来表示特定区域的不同温度。当我创建图例时,图例显示了正确的符号和颜色,但在上面添加了一条线。我使用的代码是:

colors=['b', 'c', 'y', 'm', 'r']
lo = plt.Line2D(range(10), range(10), marker='x', color=colors[0])
ll = plt.Line2D(range(10), range(10), marker='o', color=colors[0])
l = plt.Line2D(range(10), range(10), marker='o',color=colors[1])
a = plt.Line2D(range(10), range(10), marker='o',color=colors[2])
h = plt.Line2D(range(10), range(10), marker='o',color=colors[3])
hh = plt.Line2D(range(10), range(10), marker='o',color=colors[4])
ho = plt.Line2D(range(10), range(10), marker='x', color=colors[4])
plt.legend((lo,ll,l,a, h, hh, ho),('Low Outlier', 'LoLo','Lo', 'Average', 'Hi', 'HiHi', 'High Outlier'),numpoints=1, loc='lower left', ncol=3, fontsize=8)

我尝试将Line2D更改为ScatterscatterScatter返回错误,而scatter则改变了图表并返回了错误。

对于scatter,我将range(10)更改为包含数据点的列表。每个列表都包含x、y或z变量。

lo = plt.scatter(xLOutlier, yLOutlier, zLOutlier, marker='x', color=colors[0])
ll = plt.scatter(xLoLo, yLoLo, zLoLo, marker='o', color=colors[0])
l = plt.scatter(xLo, yLo, zLo, marker='o',color=colors[1])
a = plt.scatter(xAverage, yAverage, zAverage, marker='o',color=colors[2])
h = plt.scatter(xHi, yHi, zHi, marker='o',color=colors[3])
hh = plt.scatter(xHiHi, yHiHi, zHiHi, marker='o',color=colors[4])
ho = plt.scatter(xHOutlier, yHOutlier, zHOutlier, marker='x', color=colors[4])
plt.legend((lo,ll,l,a, h, hh, ho),('Low Outlier', 'LoLo','Lo', 'Average', 'Hi', 'HiHi',     'High Outlier'),scatterpoints=1, loc='lower left', ncol=3, fontsize=8)
当我运行它时,图例不再存在,只剩下一个小白框在角落里什么也没有。
有任何建议吗?

我相信在这里提供了一个更好的解决方案:链接 - dmvianna
5个回答

160

二维散点图

使用 matplotlib.pyplot 模块的 scatter 方法应该可以(至少在 Python 2.7.5 中使用 matplotlib 1.2.1 的情况下),如下面的示例代码所示。此外,如果您正在使用散点图,请在图例调用中使用 scatterpoints=1 而不是 numpoints=1 ,以使每个图例条目仅有一个点。

在下面的代码中,我使用了随机值而不是反复绘制相同的范围,使所有图形都可见(即不重叠)。

import matplotlib.pyplot as plt
from numpy.random import random

colors = ['b', 'c', 'y', 'm', 'r']

lo = plt.scatter(random(10), random(10), marker='x', color=colors[0])
ll = plt.scatter(random(10), random(10), marker='o', color=colors[0])
l  = plt.scatter(random(10), random(10), marker='o', color=colors[1])
a  = plt.scatter(random(10), random(10), marker='o', color=colors[2])
h  = plt.scatter(random(10), random(10), marker='o', color=colors[3])
hh = plt.scatter(random(10), random(10), marker='o', color=colors[4])
ho = plt.scatter(random(10), random(10), marker='x', color=colors[4])

plt.legend((lo, ll, l, a, h, hh, ho),
           ('Low Outlier', 'LoLo', 'Lo', 'Average', 'Hi', 'HiHi', 'High Outlier'),
           scatterpoints=1,
           loc='lower left',
           ncol=3,
           fontsize=8)

plt.show()

在此输入图片描述

3D散点图

要绘制3D散点图,请使用plot方法,因为图例不支持由 Axes3D 实例的 scatter 方法返回的 Patch3DCollection. 要指定标记样式,可以将其作为位置参数包含在方法调用中,如下面的示例所示。可选地,还可以在linestylemarker参数中都包含参数。

import matplotlib.pyplot as plt
from numpy.random import random
from mpl_toolkits.mplot3d import Axes3D

colors=['b', 'c', 'y', 'm', 'r']

ax = plt.subplot(111, projection='3d')

ax.plot(random(10), random(10), random(10), 'x', color=colors[0], label='Low Outlier')
ax.plot(random(10), random(10), random(10), 'o', color=colors[0], label='LoLo')
ax.plot(random(10), random(10), random(10), 'o', color=colors[1], label='Lo')
ax.plot(random(10), random(10), random(10), 'o', color=colors[2], label='Average')
ax.plot(random(10), random(10), random(10), 'o', color=colors[3], label='Hi')
ax.plot(random(10), random(10), random(10), 'o', color=colors[4], label='HiHi')
ax.plot(random(10), random(10), random(10), 'x', color=colors[4], label='High Outlier')

plt.legend(loc='upper left', numpoints=1, ncol=3, fontsize=8, bbox_to_anchor=(0, 0))

plt.show()

此处输入图片描述


1
为了在绘制3D散点图时使图例正常工作,请使用plot方法并将标记作为位置参数。请参见编辑。 - sodd
我在创建散点图的图例方面遇到了类似的问题(https://stackoverflow.com/questions/60199943/how-to-iterate-a-list-of-list-for-a-scatter-plot),你能帮我吗? - 3kstc
我正在尝试将两个并排绘制在一起,但遇到了问题,有什么指导吗? - Daniel Vilas-Boas

103

如果你使用的是matplotlib 3.1.1或更高版本,可以尝试:

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

x = [1, 3, 4, 6, 7, 9]
y = [0, 0, 5, 8, 8, 8]
classes = ['A', 'B', 'C']
values = [0, 0, 1, 2, 2, 2]
colors = ListedColormap(['r','b','g'])
scatter = plt.scatter(x, y, c=values, cmap=colors)
plt.legend(handles=scatter.legend_elements()[0], labels=classes)

results2


13
2021年最有用的答案。 - mins
1
即使没有cmap,它对我也有效。 - Armin
1
参考:https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_with_legend.html#automated-legend-creation - gia huy
4
2022年仍然非常有用。 - trianta2
如果values不是列表索引,而是指向类别的呢?例如,如果values = [10, 10, 20, 30, 30, 30]values = ['A', 'A', 'B', 'C', 'C', 'C'],上述方法就会失败。 - normanius

13

其他答案似乎有点复杂,你可以在scatter函数中添加一个参数'label',这将成为你绘图的图例。

import matplotlib.pyplot as plt
from numpy.random import random

colors = ['b', 'c', 'y', 'm', 'r']

lo = plt.scatter(random(10), random(10), marker='x', color=colors[0],label='Low Outlier')
ll = plt.scatter(random(10), random(10), marker='o', color=colors[0],label='LoLo')
l  = plt.scatter(random(10), random(10), marker='o', color=colors[1],label='Lo')
a  = plt.scatter(random(10), random(10), marker='o', color=colors[2],label='Average')
h  = plt.scatter(random(10), random(10), marker='o', color=colors[3],label='Hi')
hh = plt.scatter(random(10), random(10), marker='o', color=colors[4],label='HiHi')
ho = plt.scatter(random(10), random(10), marker='x', color=colors[4],label='High Outlier')

plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=4)

plt.show()

这是您的输出:

img


你假设数据已经按类别分开了。那么就不需要使用“scatter”了,用“plot”也可以。 - Rainald62
你假设数据已经按类别分开了。那么就不需要使用“scatter”了,用“plot”也可以。 - undefined

3
这是更简单的做法(来源:这里):
import matplotlib.pyplot as plt
from numpy.random import rand


fig, ax = plt.subplots()
for color in ['red', 'green', 'blue']:
    n = 750
    x, y = rand(2, n)
    scale = 200.0 * rand(n)
    ax.scatter(x, y, c=color, s=scale, label=color,
               alpha=0.3, edgecolors='none')

ax.legend()
ax.grid(True)

plt.show()

你将获得以下内容:

在此输入图片描述

请查看此处的图例属性。


2

我创建了一个年份的唯一值传说列表,用作散点图中的颜色。散点图变量名为result。result.legend_elements()[0]返回一个颜色列表,我使用labels=legend将颜色映射到值,其中legend是我的年份列表。

legend=[str(year) for year in df['year'].unique()]
plt.title('Battery Capicity kwh')


result = plt.scatter('Approx_Release_price_order_in_K$','Battery_Capacity_kWh',data=df,c='year',label='Class 1')
plt.ylabel('kwh')
plt.xlabel('K$')
plt.legend(handles=result.legend_elements()[0], 
       labels=legend,
       title="Year")
print('The higher priced evs have more battery capacity')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接