Matplotlib返回一个绘图对象

51

我有一个函数可以包装 pyplot.plt,因此我可以快速创建具有经常使用的默认值的图表:

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig_size=(10, 5)):

    # Skipping a lot of other complexity here

    f, axarr = plt.subplots(figsize=fig_size)
    axarr.plot(time, signal, linewidth=line_width,
               alpha=alpha, color=color)
    axarr.set_xlim(min(time), max(time))
    axarr.set_xlabel(xlab)
    axarr.set_ylabel(ylab)
    axarr.grid(show_grid)

    plt.suptitle(title, size=16)
    plt.show()

然而,有时我想要能够返回绘图,以便可以手动添加/编辑特定图形的内容。例如,我想能够更改轴标签或在调用函数后向绘图添加第二条线:

import numpy as np

x = np.random.rand(100)
y = np.random.rand(100)

plot = plot_signal(np.arange(len(x)), x)

plot.plt(y, 'r')
plot.show()

我看到了一些关于这个问题的问题(如何从Pandas plot函数返回一个matplotlib.figure.Figure对象?AttributeError: 'Figure' 对象没有 'plot' 属性),因此我尝试在函数结尾添加以下内容:

  • return axarr

  • return axarr.get_figure()

  • return plt.axes()

然而,它们都会返回类似的错误:AttributeError: 'AxesSubplot' object has no attribute 'plt'

什么是正确的方法来返回一个绘图对象以便稍后进行编辑?


2
你尝试过返回 fig = plt.gcf() 吗? - b-fg
这里的一切都是正确的,除了调用 plot.plt() 而不是 plot.plot()。愚蠢的错误;任何人都可能犯错 :) - Liran Funaro
3个回答

46

我认为这个错误已经很清楚地说明了。没有类似于pyplot.plt的东西。当被导入时,pltpyplot的准标准缩写形式,即import matplotlib.pyplot as plt

关于问题,第一种方法return axarr最为通用。你可以获取一个轴,或者一个轴数组,并可以对其进行绘图。

代码可能如下所示:

def plot_signal(x,y, ..., **kwargs):
    # Skipping a lot of other complexity here
    f, ax = plt.subplots(figsize=fig_size)
    ax.plot(x,y, ...)
    # further stuff
    return ax

ax = plot_signal(x,y, ...)
ax.plot(x2, y2, ...)
plt.show()

30

这其实是一个很好的问题,我花了几年时间才弄明白。一个很好的方法是将一个图形对象传递给你的代码,并让你的函数添加一个轴,然后返回更新后的图形。以下是一个示例:

fig_size = (10, 5)
f = plt.figure(figsize=fig_size)

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig=f):

    # Skipping a lot of other complexity here

    axarr = f.add_subplot(1,1,1) # here is where you add the subplot to f
    plt.plot(time, signal, linewidth=line_width,
               alpha=alpha, color=color)
    plt.set_xlim(min(time), max(time))
    plt.set_xlabel(xlab)
    plt.set_ylabel(ylab)
    plt.grid(show_grid)
    plt.title(title, size=16)
    
    return(f)
f = plot_signal(time, signal, fig=f)
f

7

根据matplotlib文档,建议使用的签名为:

def my_plotter(ax, data1, data2, param_dict):
    """
    A helper function to make a graph

    Parameters
    ----------
    ax : Axes
        The axes to draw to

    data1 : array
       The x data

    data2 : array
       The y data

    param_dict : dict
       Dictionary of keyword arguments to pass to ax.plot

    Returns
    -------
    out : list
        list of artists added
    """
    out = ax.plot(data1, data2, **param_dict)
    
    return out

这可以用作:

data1, data2, data3, data4 = np.random.randn(4, 100)
fig, ax = plt.subplots(1, 1)
my_plotter(ax, data1, data2, {'marker': 'x'})

您应该传递轴而非图表。一个Figure包含一个或多个AxesAxes是可以指定x-y格式、3d绘图等的区域。图表是我们用来绘制数据的东西——它可以是Jupyter笔记本或Windows GUI等。

1
如果您的函数生成由多个“Axes”组成的“Figure”,那该怎么办?如果您希望用户能够在之后修改该图形,您该如何处理? - cbrnr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接