用Matplotlib迭代子图轴数组,转化为单个列表

75

有没有一种简单/干净的方式可以迭代子图返回的轴数组,比如:

nrow = ncol = 2
a = []
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
for i, row in enumerate(axs):
    for j, ax in enumerate(row):
        a.append(ax)

for i, ax in enumerate(a):
    ax.set_ylabel(str(i))

这甚至适用于 nrowncol == 1

我尝试了列表推导式,例如:

[element for tupl in tupleOfTuples for element in tupl]

但如果 nrowsncols == 1,那么这种方法就失败了。


这个也适用于单轴的解决方案是这个 - mins
6个回答

100
ax 返回的是一个 numpy 数组,可以进行重新形状操作而不需要复制数据。如果使用以下方式,您将获得一个线性数组,可以干净地迭代它。
nrow = 1; ncol = 2;
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)

for ax in axs.reshape(-1): 
  ax.set_ylabel(str(i))

当ncols和nrows都为1时,这个说法就不成立了,因为返回值不是一个数组;为了保持一致性,你可以将返回值转化为一个只有一个元素的数组,但这感觉有些笨拙:

nrow = 1; ncol = 1;
fig, axs = plt.subplots(nrows=nrow, ncols=nrow)
axs = np.array(axs)

for ax in axs.reshape(-1):
  ax.set_ylabel(str(i))

重塑文档-1参数使重塑推断输出的维度。


14
对于nrow=ncol=1,您可以使用squeeze=0。即使两者都为1,plt.subplots(nrows=nrow, ncols=nrow, squeeze=0)始终返回一个包含两个轴的二维数组。 - cronos
这个建议行不通,因为在nrow=ncol=1的情况下,axs不是一个numpy数组。squeeze=0可以解决问题! - Julek

70
plt.subplotsfig返回值包含所有轴的列表。要遍历图中的所有子图,可以使用以下方法:
nrow = 2
ncol = 2
fig, axs = plt.subplots(nrow, ncol)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))

对于nrow == ncol == 1,这也适用。


1
那我们就不需要axs了吗? - user2739472
3
这非常简单而且非常有用! - John Mahoney
这应该是被选中的答案。 - mins

23

我不确定何时添加了squeeze关键字参数,但现在已经有了。这可以确保结果始终是一个2D numpy数组。将其转换为1D数组很容易:

fig, ax2d = subplots(2, 2, squeeze=False)
axli = ax2d.flatten()

适用于任意数量的子图,对于单个轴没有技巧,因此比接受的答案要容易一些(也许当时还不存在squeeze)。


你的第二行代码是不是想写成 ax2d.flatten()?否则就不清楚 ax1d 是什么了。 - not link
1
@notlink 哦,那就更有意义了。 - Mark

18
你可以在`axes`对象上使用`numpy`的`.flat`属性。
这里有一个例子。
fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    ## do something with instance of 'ax'

4
.flat 是numpy数组的属性,与matplotlib无关。 - ImportanceOfBeingErnest
@ImportanceOfBeingErnest 感谢您纠正我的答案。我有些困惑。 - Sukjun Kim

11
TLDR; axes.flat是遍历axes的最Pythonic方式。
正如其他人指出的,plt.subplots()的返回值是一个Axes对象的numpy数组,因此有很多内置的numpy方法可以将数组展平。其中,axes.flat 是最少冗余的访问方法。此外,axes.flatten() 返回数组的一个副本,而 axes.flat 返回数组的一个迭代器,这意味着在长期运行中,axes.flat 将更高效。
借用 @Sukjun-Kim 的例子:
fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    ## do something with instance of 'ax'

来源: axes.flat 文档 Matplotlib 教程


3

以下是一个好的实践:
例如,我们需要设置一个 4x4 的子图集合,以便我们可以像下面这样使用它们:

rows = 4; cols = 4;
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(20, 16), squeeze=0, sharex=True, sharey=True)
axes = np.array(axes)

for i, ax in enumerate(axes.reshape(-1)):
  ax.set_ylabel(f'Subplot: {i}')

输出结果美观清晰。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接