在numpy/scipy中优化旋转掩码的实现

4
这是我第一次在numpy中使用步幅,相对于简单迭代不同的过滤器,它确实提高了速度,但仍然相当缓慢(感觉至少有一两个完全冗余或低效的地方)。所以我的问题是:是否有更好的执行方式或调整我的代码可以使其显着更快?该算法对每个像素执行9个不同过滤器的局部评估,并选择具有最小标准偏差的过滤器(我试图按照图像分析书籍中描述的Nagau和Matsuyma(1980)“复杂区域照片的结构分析”进行实现)。结果是既平滑又锐化的图像(如果您问我,那相当酷!)
import numpy as np
from scipy import ndimage
from numpy.lib import stride_tricks

def get_rotating_kernels():

    kernels = list()

    protokernel = np.arange(9).reshape(3,  3)

    for k in xrange(9):

        ax1, ax2 = np.where(protokernel==k)
        kernel = np.zeros((5,5), dtype=bool)
        kernel[ax1: ax1+3, ax2: ax2+3] = 1
        kernels.append(kernel)

    return kernels


def get_rotation_smooth(im, **kwargs):

    kernels = np.array([k.ravel() for k in get_rotating_kernels()],
                dtype=bool)

    def rotation_matrix(section):

        multi_s = stride_tricks.as_strided(section, shape=(9,25),
            strides=(0, section.itemsize))

        rot_filters = multi_s[kernels].reshape(9,9)

        return rot_filters[rot_filters.std(1).argmin(),:].mean()

    return ndimage.filters.generic_filter(im, rotation_matrix, size=5, **kwargs)

from scipy import lena
im = lena()
im2 = get_rotation_smooth(im)

(仅供评论,由于几乎没有时间花在那里,get_rotating_kernel实际上没有进行优化)

在我的网络本上,它花费了126秒,而Lena毕竟是一张相当小的图像。

编辑:

我得到了建议,将rot_filters.std(1)更改为rot_filters.var(1)可以节省很多平方根,并且它可以节省大约5秒钟的时间。


你尝试过使用性能分析工具(例如cProfile)对其进行分析吗? - nneonneo
其实我还没有,因为rotation_matrix函数被调用了262144次,所以有时间可以节省。而且无论它指向哪个部分,我仍然不知道它如何帮助我...但也许只是因为我还没有学会喜欢cProfile... - deinonychusaur
2个回答

1

我相信你使用Python + scipy 进行优化会遇到困难。然而,我通过使用as_strided 直接生成rot_filters(而不是通过布尔索引)成功地进行了小幅度的改进。这是基于一个非常简单的n维windows函数。(在我意识到scipy中存在2d卷积函数之前,我写了它来解决this problem问题。)以下代码在我的机器上提供了适度的10%加速;请参阅下面的说明以了解其工作原理:

import numpy as np
from scipy import ndimage
from numpy.lib import stride_tricks

# pass in `as_strided` as a default arg to save a global lookup
def rotation_matrix2(section, _as_strided=stride_tricks.as_strided):
    section = section.reshape(5, 5)  # sqrt(section.size), sqrt(section.size)
    windows_shape = (3, 3, 3, 3)     # 5 - 3 + 1, 5 - 3 + 1, 3, 3
    windows_strides = section.strides + section.strides
    windows = _as_strided(section, windows_shape, windows_strides)
    rot_filters = windows.reshape(9, 9)
    return rot_filters[rot_filters.std(1).argmin(),:].mean()

def get_rotation_smooth(im, _rm=rotation_matrix2, **kwargs):
    return ndimage.filters.generic_filter(im, _rm, size=5, **kwargs)

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from scipy.misc import lena
    im = lena()
    im2 = get_rotation_smooth(im)
    #plt.gray()      # Uncomment these lines for
    #plt.imshow(im2) # demo purposes.
    #plt.show()

上述函数rotation_matrix2等价于下面的两个函数(实际上,这两个函数结合起来比你原来的函数略慢,因为windows更加通用)。它完全与您原始代码所做的相同——创建9个3x3的窗口进入一个5x5的数组,然后将它们重新调整形状成9x9的数组进行处理。
def windows(a, w, _as_strided=stride_tricks.as_strided):
    windows_shape = tuple(sa - sw + 1 for sa, sw in zip(a.shape, w))
    windows_shape += w
    windows_strides = a.strides + a.strides
    return _as_strided(a, windows_shape, windows_strides)

def rotation_matrix1(section, _windows=windows):
    rot_filters = windows(section.reshape(5, 5), (3, 3)).reshape(9, 9)
    return rot_filters[rot_filters.std(1).argmin(),:].mean()

Windows可以处理任意维度的数组,只要窗口的维度数量相同。以下是其工作原理的详细说明:

    windows_shape = tuple(sa - sw + 1 for sa, sw in zip(a.shape, w))

我们可以将windows数组视为n-d数组的n-d数组。外部n-d数组的形状由窗口在较大数组中的自由度决定;在每个维度上,窗口可以占据的位置数等于较大数组的长度减去窗口的长度加一。在这种情况下,我们有一个3x3的窗口进入一个5x5的数组,因此外部2-d数组是一个3x3的数组。
    windows_shape += w

内部n-d数组的形状与窗口本身的形状相同。在我们的情况下,这又是一个3x3的数组。

现在来看步幅。我们必须为外部n-d数组和内部n-d数组定义步幅。但事实证明它们是相同的!毕竟,窗口通过较大的数组移动的方式与单个索引移动数组的方式相同,对吧?

    windows_strides = a.strides + a.strides

现在我们已经拥有了创建窗口所需的所有信息:

    return _as_strided(a, windows_shape, windows_strides)

我将windows_strides的定义移到了def之外,并将std(1)更改为var(1)。正如你所指出的,我认为在Python方面没有太多要做的了。 - deinonychusaur

1

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接