如何创建一个奇数个子图

26
我正在尝试创建一个绘图函数,该函数以所需的图表数量作为输入,并使用pylab.subplotssharex=True选项进行绘制。如果所需的图表数量是奇数,则我想移除最后一个面板,并强制在其上方的面板上显示刻度标签。我找不到同时进行这两个操作并使用sharex=True选项的方法。子图的数量可能相当大(>20)。
以下是示例代码。在此示例中,当i=3时,我希望强制显示xtick标签。
import numpy as np
import matplotlib.pylab as plt

def main():
    n = 5
    nx = 100
    x = np.arange(nx)
    if n % 2 == 0:
        f, axs = plt.subplots(n/2, 2, sharex=True)
    else:
        f, axs = plt.subplots(n/2+1, 2, sharex=True)
    for i in range(n):
        y = np.random.rand(nx)
        if i % 2 == 0:
            axs[i/2, 0].plot(x, y, '-', label='plot '+str(i+1))
            axs[i/2, 0].legend()
        else:
            axs[i/2, 1].plot(x, y, '-', label='plot '+str(i+1))
            axs[i/2, 1].legend()
    if n % 2 != 0:
        f.delaxes(axs[i/2, 1])
    f.show()


if __name__ == "__main__":
     main()
5个回答

48
简单来说,你让子图调用偶数(在这种情况下是6个子图):
f, ax = plt.subplots(3, 2, figsize=(12, 15))

然后你删除你不需要的那一个:

f.delaxes(ax[2,1]) # The indexing is zero-based here

这个问题和回答是以自动化的方式来看待它,但我认为在这里发布基本用例是值得的。


18

如果你将你的main函数中的最后一个if替换为以下内容:

if n % 2 != 0:
    for l in axs[i/2-1,1].get_xaxis().get_majorticklabels():
        l.set_visible(True)
    f.delaxes(axs[i/2, 1])

f.show()

它应该能起作用:

绘图


当处理大量轴对象(子图)时,使用delaxes似乎效率低下。我最终使用add_subplot来完成这个任务。但我不确定如何获取subplots中可用的sharexsharey - CMCDragonkai

5

我经常生成任意数量的子图(有时数据会导致3个子图,有时是13个等),我编写了一个小型实用程序函数来避免思考。

我定义的两个函数如下。您可以更改样式选择以匹配您的偏好。

import math
import numpy as np
from matplotlib import pyplot as plt


def choose_subplot_dimensions(k):
    if k < 4:
        return k, 1
    elif k < 11:
        return math.ceil(k/2), 2
    else:
        # I've chosen to have a maximum of 3 columns
        return math.ceil(k/3), 3


def generate_subplots(k, row_wise=False):
    nrow, ncol = choose_subplot_dimensions(k)
    # Choose your share X and share Y parameters as you wish:
    figure, axes = plt.subplots(nrow, ncol,
                                sharex=True,
                                sharey=False)

    # Check if it's an array. If there's only one plot, it's just an Axes obj
    if not isinstance(axes, np.ndarray):
        return figure, [axes]
    else:
        # Choose the traversal you'd like: 'F' is col-wise, 'C' is row-wise
        axes = axes.flatten(order=('C' if row_wise else 'F'))

        # Delete any unused axes from the figure, so that they don't show
        # blank x- and y-axis lines
        for idx, ax in enumerate(axes[k:]):
            figure.delaxes(ax)

            # Turn ticks on for the last ax in each column, wherever it lands
            idx_to_turn_on_ticks = idx + k - ncol if row_wise else idx + k - 1
            for tk in axes[idx_to_turn_on_ticks].get_xticklabels():
                tk.set_visible(True)

        axes = axes[:k]
        return figure, axes

以下是使用13个子图的示例:

x_variable = list(range(-5, 6))
parameters = list(range(0, 13))

figure, axes = generate_subplots(len(parameters), row_wise=True)
for parameter, ax in zip(parameters, axes):
    ax.plot(x_variable, [x**parameter for x in x_variable])
    ax.set_title(label="y=x^{}".format(parameter))

plt.tight_layout()
plt.show()

这将生成以下内容:

进入图像描述

或者,切换到列优先遍历顺序 (generate_subplots(..., row_wise=False)) 会生成:

进入图像描述


3

与其通过计算来检测需要删除的子图,不如查看哪个子图上没有打印任何内容。您可以参考此答案中提供的各种检查轴上是否绘制了内容的方法。使用函数ax.has_Data(),您可以像这样简化您的函数:

def main():
    n = 5
    max_width = 2 ##images per row
    height, width = n//max_width +1, max_width
    fig, axs = plt.subplots(height, width, sharex=True)

    for i in range(n):
        nx = 100
        x = np.arange(nx)
        y = np.random.rand(nx)
        ax = axs.flat[i]
        ax.plot(x, y, '-', label='plot '+str(i+1))
        ax.legend(loc="upper right")

    ## access each axes object via axs.flat
    for ax in axs.flat:
        ## check if something was plotted 
        if not bool(ax.has_data()):
            fig.delaxes(ax) ## delete if nothing is plotted in the axes obj

    fig.show()

您还可以使用n参数指定要显示的图像数量,并使用max_width参数指定每行所需的最大图像数。

输入图像描述


非常好的答案。运行良好,适用于子图数量未知的动态环境。谢谢。 - Wrichik Basu

-4

对于 Python 3,您可以按如下方式删除:

# I have 5 plots that i want to show in 2 rows. So I do 3 columns. That way i have 6 plots.
f, axes = plt.subplots(2, 3, figsize=(20, 10))

sns.countplot(sales_data['Gender'], order = sales_data['Gender'].value_counts().index, palette = "Set1", ax = axes[0,0])
sns.countplot(sales_data['Age'], order = sales_data['Age'].value_counts().index, palette = "Set1", ax = axes[0,1])
sns.countplot(sales_data['Occupation'], order = sales_data['Occupation'].value_counts().index, palette = "Set1", ax = axes[0,2])
sns.countplot(sales_data['City_Category'], order = sales_data['City_Category'].value_counts().index, palette = "Set1", ax = axes[1,0])
sns.countplot(sales_data['Marital_Status'], order = sales_data['Marital_Status'].value_counts().index, palette = "Set1", ax = axes[1, 1])

# This line will delete the last empty plot
f.delaxes(ax= axes[1,2]) 

1
这个答案完全没有用:代码引用了一个外部模块(seaborn),并包含一些外部数据(性别、年龄等),读者无法访问。因此,这基本上是其他人提出的解决方案的转载,没有任何改进或其他原创工作。 - Aelius

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接