加速我的交互式matplotlib图表

3
我正在将一个jpeg图像(大型:1300x2000)加载到matplotlib中,在其上绘制一个50x50的方格,并单击每个方格进行着色编码。然而,我注意到程序在我的点击之后远远落后,并且如果我以合理的速度快速点击50个方格,则需要长达30秒才能赶上。我想知道是否有人可以加速。以下是我的脚本,如果您复制/粘贴它并具有scipy,numpy,matplotlib,pillow和tkinter,则可以立即使用。

欢迎任何建议。我是一名医学科学家,所以请原谅我如果代码没有很好地解释:

import matplotlib
import matplotlib.pyplot as plt
import tkinter
import tkinter.filedialog
from matplotlib.figure import  Figure
import math, sys
import numpy as np
import scipy.io as sio
from PIL import Image
from numpy import arange, sin, pi
#from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
#matplotlib.matplotlib_fname()
import os, re

global stridesize, classnumber,x, im,fn, plt, fig, mask

classnumber = 1



def onmove(eve):
    global x,im, plt
    print(eve.ydata)
    print(eve.button)
    if (eve.ydata !=None) and (eve.xdata !=None):
        if eve.button==1:
            print(eve.button)
            xcoord = int(eve.xdata)
            ycoord = int(eve.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
           # print(eve.xdata, int(eve.ydata), stridesize)
            if(classnumber==1):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
            if(classnumber==2):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
            if(classnumber==3):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
            if(classnumber==4):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
            if(classnumber==5):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
            if(classnumber==6):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
            if(classnumber==7):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])


        if eve.button==3:
            xcoord = int(eve.xdata)
            ycoord = int(eve.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(eve.xdata, int(eve.ydata), stridesize)
            mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
        im.set_data(mask)
        fig.canvas.draw()




def onclick(event):

    if (event.ydata !=None) and (event.xdata !=None):
        global x, im, fig
        if event.button==1:
            xcoord = int(event.xdata)
            ycoord = int(event.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(event.xdata, int(event.ydata), stridesize)
            if(classnumber==1):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
            if(classnumber==2):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
            if(classnumber==3):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
            if(classnumber==4):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
            if(classnumber==5):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
            if(classnumber==6):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
            if(classnumber==7):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])
            im.set_data(mask)

        if event.button==3:
            xcoord = int(event.xdata)
            ycoord = int(event.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(event.xdata, int(event.ydata), stridesize)
            mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
            im.set_data(mask)
        fig.canvas.draw()


def onpress(event):
    global classnumber, mask
    if (event.key == 'e'):
       print("YO")
       mask[:,:,:]=0;
       im.set_data(mask)
       fig.canvas.draw()

    if (event.key=='s'):
        savemask(fn)
    if (event.key=='r'):
        plt.figure();
        plt.imshow(mask);
        plt.show();
    if int(event.key) > 0 and int(event.key) <9 :
       classnumber = int(event.key)
       print(classnumber)


def onrelease(event):
    print(event.button)
 #   im.set_data(mask)




def savemask(fn):
    # matrixname =os.path.basename(filename)
    # matrixname = re.sub(r'\.jpg','',matrixname)
    pre, ext = os.path.splitext(fn)
    savename_default = os.path.basename(pre)
    options = {}
    options['defaultextension'] = ''
    options['filetypes'] = [('mat files', '.mat')]
    options['initialdir'] = ''
    options['initialfile'] = savename_default
    options['title'] = 'Save file'


    f = tkinter.filedialog.asksaveasfile(**options)
    if f is None: # asksaveasfile return `None` if aadialog closed with "cancel".
        return
    name = f.name
    sio.savemat(name,{'mask':mask},do_compression=True)
    f.close()




root = tkinter.Tk()
root.withdraw()

options = {}

options['defaultextension'] = '.jpg'

options['filetypes'] = [('Jpeg', '.jpg')]

options['initialdir'] = 'C:\\'
options['initialfile']= ''
options['parent'] = root

options['title'] = 'This is a title'


fn= tkinter.filedialog.askopenfilename(**options)


img = Image.open(fn)
x = np.asarray(img)
x.setflags(write=1)
#masksize= (x.shape[0],x.shape[1],4)
mask= np.zeros(x.shape,'uint8')
#mask[:,:,3]=0.2
fig = plt.figure()
fig.suptitle(r'Key codes: 1 = Tumour, 2 = stroma-hypocellular, 3=stroma cellular (inflammatory)' '\n4 = proteinaceous, 5= red cells, 6,7: anyother,''\nRight click: clear square''\n r:  review mask, e: erase mask, o : open mask image, s : save mask image;')


im=plt.imshow(x)
im=plt.imshow(mask,alpha=.25)
ax = plt.gca();

stridesize = 50;

plt.rcParams['keymap.save']=''
ax.set_yticks(np.arange(0, x.shape[0], stridesize));
ax.set_xticks(np.arange(0, x.shape[1], stridesize));

cid = fig.canvas.mpl_connect('button_press_event', onclick)
cod = fig.canvas.mpl_connect('key_press_event', onpress)
#cdd = fig.canvas.mpl_connect('motion_notify_event', onmove)
cdr = fig.canvas.mpl_connect('button_release_event', onrelease)

plt.grid(b=True, which='both', color='black',linestyle='-')
#
plt.show()

plt.ion()

对于那些想知道上述代码基本错误的人来说,更改一行代码使速度翻倍。因此,我将每个调用更改为:fig.canvas.draw() 到 fig.canvas.draw_idle()。 - Maelstorm
1个回答

3

首先,我建议尽可能避免使用全局变量。您可以使用class来代替它。下面是您的代码意图的完整工作摘要:

import numpy as np

import matplotlib
matplotlib.use('Qt4Agg')

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

class ColorCode(object):

    def __init__(self, block_size=(50,50), colors=['red', 'green', 'blue'], alpha=0.3):
        self.by, self.bx = block_size # block size
        self.selected = 0 # selected color
        self.colors = colors
        self.cmap = ListedColormap(colors) # color map for labels
        self.mask = None # annotation mask
        self.alpha = alpha
        # Plots
        self.fig = plt.figure()
        self.ax = self.fig.gca()
        # Events
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)

    def color_code(self, img):
        self.imshape = img.shape[:2]
        self.mask = np.full(img.shape[:2], -1, np.int32) # masked labels
        self.ax.imshow(img) # show image
        self.ax.imshow(np.ma.masked_where(self.mask < 0, self.mask), cmap=self.cmap,
                       alpha=self.alpha, vmin=0, vmax=len(self.colors)) # show mask
        # Run
        plt.show(block=True)
        return self.mask

    def on_click(self, event):
        if not event.inaxes or self.mask is None:
            return
        # Get corresponding coordinates
        py, px = int(event.ydata), int(event.xdata)
        cy, cx = py//self.by, px//self.bx # grid coordinates
        ymin = cy * self.by
        ymax = min((cy+1) * self.by, self.imshape[0])
        xmin = cx * self.bx
        xmax = min((cx+1) * self.bx, self.imshape[1])
        # Update mask
        if event.button == 1:
            self.mask[ymin:ymax, xmin:xmax] = self.selected
        elif event.button == 3:
            self.mask[ymin:ymax, xmin:xmax] = -1
        # Update figure
        self.ax.images[1].set_data(np.ma.masked_where(self.mask < 0, self.mask))
        self.fig.canvas.draw_idle()

    def on_key(self, event):
        ikey = int(event.key)
        if 0 <= ikey < len(self.colors):
            self.selected = ikey

您的代码与以下主要不同:
  1. 它不使用全局变量,而是使用类变量。使其更安全地运行并更易于扩展/修改。

  2. 注释不是将颜色应用于三维 mask,而是将注释保存为二维 mask,其中每个像素的值在范围 [1,len(colors)] 中,指示它属于哪种颜色。然后通过使用 ListedColormap 在图中添加颜色来设置自定义颜色映射。

  3. 它绘制图像,并在其上叠加分割掩模。最初,掩模填充为 -1,表示它没有标签。通过使用 numpy 的 masked array,您可以将 mask < 0 的位置掩盖以便于在绘图中不显示,而在其他位置上则以彩色方式呈现。

  4. 提供了可能颜色的列表作为类的参数。它将允许您从 0 到 len(colors) 选择颜色,最多可使用 10 种颜色(因为它当前绑定到键盘中的数字)。

  5. fig.canvas.draw_idlefig.canvas.draw 好得多。后者会阻塞程序,直到绘制完成。

  6. 由于所有内容都在类中,因此代码看起来更加清晰简洁。

您可以按以下方式调用代码:

>>> random_image = np.random.randn(1000,2000, 3)
>>> result = ColorCode().color_code(random_image)

并且result会包含标记为mask的标签,其中每个像素都有一个数字表示用哪种颜色进行标记(如果没有则为-1)。最后,可以将其他参数传递给ColorCode的构造函数,例如block_size =(100,100)以获取不同的块大小,alpha = 0.5以减少掩膜的不透明度(或将alpha = 1设置为None)。

希望对您有所帮助,或者至少可以从中获取一些想法。


非常感谢。速度至少提高了两倍。 - Maelstorm
@Maelstorm编辑了问题,添加了左键单击右键单击之间的区别。现在您可以使用右键单击擦除。此外,我认为如果您不是用单击+重绘的方式,而是使用按住+移动鼠标并保存路径+松开+重绘的方式,可能会大大提高效率,这样您就可以在按下按钮的同时拖动鼠标浏览所有有趣的区域,并且当释放时,它们都将被涂成这种颜色。但编码会略微困难一些。 - Imanol Luengo
能否使用其他后端?我安装了Py3和PyQt5,但是无论如何更改matplotlib.use(),我仍然会遇到PyQt4错误。 - percusse
1
@percusse 只需使用 matplotlib.use('Qt5Agg') 就可以正常工作。如果不行,尝试添加 force=True,如 matplotlib.use('Qt5Agg', force=True)。但是,上述代码中的所有内容在我这里都可以在 python2 和 python3 中正常工作。 - Imanol Luengo

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接