matplotlib中的axes.flat是什么作用?

40

我看到过很多使用matplotlib的程序使用axes.flat函数,就像这段代码:

for i, ax in enumerate(axes.flat):

这是做什么用的?

2个回答

74

让我们来看一个最简单的例子,我们使用plt.subplots创建一些坐标轴,也可以参考这个问题

import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=2,nrows=3, sharex=True, sharey=True)

for i, ax in enumerate(axes.flat):
    ax.scatter([i//2+1, i],[i,i//3])
    
plt.show()

在这里,axes是一个包含轴的numpy数组。

print(type(axes))
> <type 'numpy.ndarray'>
print(axes.shape)
> (3L, 2L)

axes.flat 不是一个函数,它是 numpy.ndarray 的一个属性:numpy.ndarray.flat

ndarray.flat 是数组的一个一维迭代器。
这是一个 numpy.flatiter 实例,类似于 Python 内置的迭代器对象,但不是其子类。

示例:

import numpy as np

a = np.array([[2,3],
              [4,5],
              [6,7]])
    
for i in a.flat:
    print(i)

它将输出数字2 3 4 5 6 7

作为数组的迭代器,您可以使用它来循环遍历所有来自三行两列轴数组的轴。

for i, ax in enumerate(axes.flat):
在每次迭代中,它会从该数组中产生下一个轴,以便您可以在单个循环中轻松地绘制所有轴。另一种方法是使用`axes.flatten()`,其中`flatten()`是numpy数组的方法。它返回数组的扁平化版本,而不是迭代器:
for i, ax in enumerate(axes.flatten()):

从外部看,这两者没有任何区别。然而,迭代器实际上不会创建一个新的数组,因此可能会稍微快一些(尽管在matplotlib axes对象的情况下几乎不会注意到这一点)。

flat1 = [ax for ax in axes.flat]
flat2 = axes.flatten()
print(flat1 == flat2)
> [ True  True  True  True  True  True]

使用展平版本的轴数组进行迭代的优点是,与单独迭代行和列的天真方法相比,您将节省一个循环。

for row in axes:
    for ax in row:
        ax.scatter(...)

1
fig, ax = plt.subplots(3, 3, figsize=())
ax = ax.flatten()
for i, col in enumerate(columns):
    sns.distplot(d2[col], ax=ax[i])
plt.tight_layout()

5
请将文本翻译成中文。仅返回翻译后的文本,并附加说明。 - Sabito stands with Ukraine

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接