Matplotlib绘图非常缓慢

4

请问有谁能帮我优化 Python 中的绘图函数吗?我使用 Matplotlib 来绘制金融数据。下面是一个用于绘制 OHLC 数据的小函数。当我添加指标或其他数据时,绘图耗时会显著增加。

import numpy as np
import datetime
from matplotlib.collections import LineCollection
from pylab import *
import urllib2

def _ohlc_segments(x_start, y_start, x_end, y_end):
    """Return an (N, 2, 2) array of segments running from start to end points."""
    starts = np.column_stack((x_start, y_start))
    ends = np.column_stack((x_end, y_end))
    # Stacking gives (2, N, 2); swap so that each row is one [start, end] pair.
    return np.array([starts, ends]).swapaxes(0, 1)


def test_plot(OHLCV):
    """Draw an OHLC bar chart of ``OHLCV`` and save it as ``test.png``.

    Args:
        OHLCV: 2-D array with one row per bar. Column 0 holds the dates and
            columns 1-4 hold open, high, low and close prices.

    Each bar is a vertical high/low line plus a left tick at the open and a
    right tick at the close. Bars whose close is above the open are green,
    all others are red.
    """
    bar_width = 1.3
    date_offset = 0.5
    fig = figure(figsize=(50, 20), facecolor='w')
    ax = fig.add_subplot(1, 1, 1)
    setp(ax.get_xmajorticklabels(), rotation=0)

    # Pull the columns out once as float arrays so the segment maths below is
    # vectorised instead of zipping Python objects pair by pair.
    dates = np.asarray(date2num(OHLCV[:, 0]), dtype=float)
    opens = np.asarray(OHLCV[:, 1], dtype=float)
    highs = np.asarray(OHLCV[:, 2], dtype=float)
    lows = np.asarray(OHLCV[:, 3], dtype=float)
    closes = np.asarray(OHLCV[:, 4], dtype=float)

    colors = ['green' if up else 'red' for up in opens < closes]

    # Every price comes from the OHLCV argument; the previous version read the
    # tick end points from a module-level global instead.
    lines_hl = LineCollection(_ohlc_segments(dates, highs, dates, lows))
    lines_op = LineCollection(
        _ohlc_segments(dates - date_offset, opens, dates, opens))
    lines_cl = LineCollection(
        _ohlc_segments(dates + date_offset, closes, dates, closes))

    # Same insertion order as before: high/low, then close, then open.
    for collection in (lines_hl, lines_cl, lines_op):
        collection.set_color(colors)
        collection.set_linewidth(bar_width)
        ax.add_collection(collection, autolim=True)

    ax.xaxis.set_major_locator(MonthLocator())
    ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
    ax.xaxis.set_minor_locator(DayLocator())

    ax.autoscale_view()

    ax.xaxis.grid(True, 'major')
    ax.grid(True)

    ax.set_title('EOD test plot')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price , $')
    fig.savefig('test.png', dpi=50, bbox_inches='tight')
    close(fig)

if __name__=='__main__':

    # Skip the first line of the CSV, then reverse the order of the rows.
    data_table = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv").readlines()[1:][::-1]
    parsed_table = []
    #Format:  Date, Open, High, Low, Close, Volume
    # One converter per column, applied positionally to each row's fields.
    dtype = (lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(),float, float, float, float, int)

    for row in data_table:

        # Drop the last CSV field so the rest line up with the converters.
        field = row.strip().split(',')[:-1]
        data_tmp = [convert(text) for convert, text in zip(dtype, field)]
        parsed_table.append(data_tmp)

    parsed_table = np.array(parsed_table)

    import time
    bf = time.time()
    count = 100
    for _ in xrange(count):
        test_plot(parsed_table)
    # Divide before formatting: '%' binds tighter than '/', so the unbracketed
    # form tried to divide the formatted string by ``count``.
    print('Plot time: %s' % ((time.time() - bf) / count))

结果如下所示。每个图表的平均执行时间约为 2.6 秒。在 R 中绘制图表要快得多,但我没有测量过它的性能,也不想使用 Rpy,所以我认为是我的代码效率低下。(原文此处附有一张结果图片。)

你在那里使用了很多的zip和拼接,但是没有太多的注释说明它实现了什么功能。作为曾经为了简洁的代码而一直这样做的人,我建议不要这样做。回来后再去审查这些代码会很痛苦... - will
我建议尝试这个。虽然我想用演示来回答这个问题,但我正在用手机回答,那将是一场噩梦... - will
我的建议是通过创建一个“FinanceChart”类来重复使用您的图形,轴线,线条集合和标签。这样,您可以重复使用图表对象100次,从而节省销毁并每次重新创建它的时间。您可以在每个线条集合上使用“set_segments”来仅更改数据。 - John Lyon
1个回答

4
这个解决方案重复使用一个 Figure 实例,并异步保存图表。你可以更改它,让每个处理器都有自己的图表,这样就可以异步进行多个图表绘制,从而进一步提高速度。目前,每个图表需要约 1 秒钟才能完成绘制,相比我的机器上的 2.6 秒已经有了明显提升。
import numpy as np
import datetime
import urllib2
import time
import multiprocessing as mp
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.collections import LineCollection

class AsyncPlotter(object):
    """Save Matplotlib figures in background processes.

    Every :meth:`save` call starts one process. A counter shared through a
    ``multiprocessing`` manager limits how many of those processes are
    writing a file at the same time.
    """

    def __init__(self, processes=mp.cpu_count()):
        """Create the shared state.

        Args:
            processes: maximum number of saves allowed to run concurrently.
        """
        self.manager = mp.Manager()
        self.nc = self.manager.Value('i', 0)
        # ``nc.value += 1`` on a manager proxy is a separate read and write,
        # so all updates of the counter are serialised with this lock.
        self.lock = self.manager.Lock()
        self.pids = []
        self.processes = processes

    def async_plotter(self, nc, fig, filename, processes):
        """Wait for a free slot, then write ``fig`` to ``filename``.

        Args:
            nc: shared counter of saves currently in progress.
            fig: figure to save and then close.
            filename: destination path passed to ``fig.savefig``.
            processes: maximum value ``nc`` may reach.
        """
        # Check and increment under one lock so two workers cannot both claim
        # the last free slot or overwrite each other's update.
        while True:
            with self.lock:
                if nc.value < processes:
                    nc.value += 1
                    break
            time.sleep(0.1)
        try:
            print("Plotting " + filename)
            fig.savefig(filename)
            plt.close(fig)
        finally:
            # Release the slot even when saving fails; otherwise the counter
            # would stay raised and later saves could wait forever.
            with self.lock:
                nc.value -= 1

    def save(self, fig, filename):
        """Start a background process that saves ``fig`` to ``filename``."""
        p = mp.Process(target=self.async_plotter,
                       args=(self.nc, fig, filename, self.processes))
        p.start()
        self.pids.append(p)

    def join(self):
        """Block until every process started by :meth:`save` has finished."""
        for p in self.pids:
            p.join()

class FinanceChart(object):
    """OHLC bar chart that reuses a single figure for every render.

    The figure, axes, tick locators and the three line collections are built
    once in ``__init__``. ``test_plot`` only replaces the segment data and
    hands the figure to the plotter for saving.
    """

    # Placeholder polyline given to each collection at construction time; it
    # is replaced by real data through ``set_segments`` in ``test_plot``.
    _PLACEHOLDER = [[(734881, 1), (734882, 5), (734883, 9), (734889, 5)]]

    def __init__(self, async_plotter):
        """Build the reusable figure.

        Args:
            async_plotter: object whose ``save(fig, filename)`` method writes
                the figure out.
        """
        self.async_plotter = async_plotter
        self.bar_width = 1.3
        self.date_offset = 0.5
        self.fig = plt.figure(figsize=(50, 20), facecolor='w')
        self.ax = self.fig.add_subplot(1, 1, 1)
        self.labels = self.ax.get_xmajorticklabels()
        setp(self.labels, rotation=0)

        # High/low line, open tick and close tick, each bound to the
        # attribute that matches its role.
        self.lines_hl = self._new_collection()
        self.lines_op = self._new_collection()
        self.lines_cl = self._new_collection()

        self.ax.set_title('EOD test plot')
        self.ax.set_xlabel('Date')
        self.ax.set_ylabel('Price , $')

        self.ax.xaxis.set_major_locator(MonthLocator())
        self.ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
        self.ax.xaxis.set_minor_locator(DayLocator())

    def _new_collection(self):
        """Add a placeholder line collection to the axes and return it."""
        return self.ax.add_collection(LineCollection(self._PLACEHOLDER),
                                      autolim=True)

    @staticmethod
    def _segments(x_start, y_start, x_end, y_end):
        """Return an (N, 2, 2) array of segments from start to end points."""
        starts = np.column_stack((x_start, y_start))
        ends = np.column_stack((x_end, y_end))
        # Stacking gives (2, N, 2); swap so each row is one [start, end] pair.
        return np.array([starts, ends]).swapaxes(0, 1)

    @staticmethod
    def _bar_colors(opens, closes):
        """Return 'green' where close is above open and 'red' elsewhere."""
        return ['green' if o < c else 'red' for o, c in zip(opens, closes)]

    def test_plot(self, OHLCV, i):
        """Load ``OHLCV`` into the chart and queue it for saving.

        Args:
            OHLCV: 2-D array with one row per bar. Column 0 holds the dates
                and columns 1-4 hold open, high, low and close prices.
            i: integer used to build the output file name ``'%04i.png'``.
        """
        dates = np.asarray(date2num(OHLCV[:, 0]), dtype=float)
        opens = np.asarray(OHLCV[:, 1], dtype=float)
        highs = np.asarray(OHLCV[:, 2], dtype=float)
        lows = np.asarray(OHLCV[:, 3], dtype=float)
        closes = np.asarray(OHLCV[:, 4], dtype=float)

        colors = self._bar_colors(opens, closes)
        offset = self.date_offset

        updates = (
            (self.lines_hl, self._segments(dates, highs, dates, lows)),
            (self.lines_op,
             self._segments(dates - offset, opens, dates, opens)),
            (self.lines_cl,
             self._segments(dates + offset, closes, dates, closes)),
        )
        for collection, segments in updates:
            collection.set_segments(segments)
            collection.set_color(colors)
            collection.set_linewidth(self.bar_width)

        # Pad the x range by the tick length so the first open tick and the
        # last close tick stay inside the view.
        self.ax.set_xlim(dates.min() - offset, dates.max() + offset)
        # The vertical extent is the lowest low and the highest high; using
        # the open column here clipped the high/low lines.
        self.ax.set_ylim(lows.min(), highs.max())

        self.ax.xaxis.grid(True, 'major')
        self.ax.grid(True)
        self.async_plotter.save(self.fig, '%04i.png'%i)

if __name__=='__main__':
    print("Starting")
    # Skip the first line of the CSV, then reverse the order of the rows.
    data_table = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv").readlines()[1:][::-1]
    parsed_table = []
    #Format:  Date, Open, High, Low, Close, Volume
    # One converter per column, applied positionally to each row's fields.
    dtype = (lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(),float, float, float, float, int)

    for row in data_table:
        # Drop the last CSV field so the rest line up with the converters.
        field = row.strip().split(',')[:-1]
        data_tmp = [convert(text) for convert, text in zip(dtype, field)]
        parsed_table.append(data_tmp)

    parsed_table = np.array(parsed_table)
    import time
    bf = time.time()
    count = 10

    a = AsyncPlotter()
    _chart = FinanceChart(a)

    print("Done with startup tasks")
    for i in xrange(count):
        _chart.test_plot(parsed_table, i)

    # Wait for the background saves inside the guard: ``a`` and ``bf`` exist
    # only when the file is run as a script.
    a.join()
    print('Plot time: %.2f' %(float(time.time() - bf) / float(count)))

网页内容由 Stack Overflow 提供,点击上方的“原文链接”可以查看英文原文。