matplotlib在散点图中没有显示图例。

6

我正试图解决一个聚类问题,为此我需要为我的聚类绘制散点图。

%matplotlib inline
import matplotlib.pyplot as plt
df = pd.merge(dataframe,actual_cluster)
plt.scatter(df['x'], df['y'], c=df['cluster'])
plt.legend()
plt.show()

df['cluster']是实际的聚类编号。因此,我希望它成为我的颜色代码。

enter image description here

它显示了一个图表,但没有显示图例。同时也没有报错。

我做错了什么吗?

2个回答

5

编辑:

生成一些随机数据:

from scipy.cluster.vq import kmeans2
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

n_clusters = 10
df = pd.DataFrame({'x':np.random.randn(1000), 'y':np.random.randn(1000)})
_, df['cluster'] = kmeans2(df, n_clusters)

更新

  • 使用seaborn.relplot函数并设置kind='scatter',或使用seaborn.scatterplot函数。
    • 指定hue='cluster'
# figure level plot
sns.relplot(data=df, x='x', y='y', hue='cluster', palette='tab10', kind='scatter')

enter image description here

# axes level plot
fig, axes = plt.subplots(figsize=(6, 6))
sns.scatterplot(data=df, x='x', y='y', hue='cluster', palette='tab10', ax=axes)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))

enter image description here

Original Answer

Plotting (matplotlib v3.3.4):

fig, ax = plt.subplots(figsize=(8, 6))
cmap = plt.cm.get_cmap('jet')
for i, cluster in df.groupby('cluster'):
    _ = ax.scatter(cluster['x'], cluster['y'], color=cmap(i/n_clusters), label=i, ec='k')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

结果:

在此输入图片描述

解释:

不深入讨论matplotlib内部细节,一次只绘制一个簇似乎能够解决这个问题。 具体而言,ax.scatter()返回一个PathCollection对象,我们在这里明确地丢弃了它,但是它似乎被传递给某种图例处理程序的内部。一次性绘制只会生成一个PathCollection/label对,而逐个绘制每个簇会生成n_clustersPathCollection/label对。您可以通过调用ax.get_legend_handles_labels()来查看这些对象,它返回类似以下内容:

([<matplotlib.collections.PathCollection at 0x7f60c2ff2ac8>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9d68>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9390>,
  <matplotlib.collections.PathCollection at 0x7f60c2f802e8>,
  <matplotlib.collections.PathCollection at 0x7f60c2f809b0>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9908>,
  <matplotlib.collections.PathCollection at 0x7f60c2f85668>,
  <matplotlib.collections.PathCollection at 0x7f60c2f8cc88>,
  <matplotlib.collections.PathCollection at 0x7f60c2f8c748>,
  <matplotlib.collections.PathCollection at 0x7f60c2f92d30>],
 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])

实际上,ax.legend() 相当于 ax.legend(*ax.get_legend_handles_labels())

注意事项:

  1. 如果使用 Python 2,请确保 i/n_clusters 是一个 float

  2. 省略 fig, ax = plt.subplots() ,并使用 plt.<method> 而不是 ax.<method> ,这样做也可以正常工作,但我总是更喜欢明确地指定我正在使用的 Axes 对象,而不是隐式地使用 "current axes" (plt.gca())。


旧的简单解决方案

如果您可以接受颜色条(而不是离散值标签),则可以使用 Pandas 内置的 Matplotlib 功能:

df.plot.scatter('x', 'y', c='cluster', cmap='jet')

enter image description here


4

这是困扰我很久的问题。现在,我想提供另一个简单的解决方案。我们不必编写任何循环!!!

def vis(ax, df, label, title="visualization"):
    points = ax.scatter(df[:, 0], df[:, 1], c=label, label=label, alpha=0.7)
    ax.set_title(title)
    ax.legend(*points.legend_elements(), title="Classes")

简单而完美的解决方案。 - Dev. R

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接