提高代码效率:滑动窗口上的标准差

23

我正在尝试改进一个函数,该函数计算图像中每个像素周围像素的标准差。我的函数使用两个嵌套循环来遍历矩阵,这是程序的瓶颈。我猜测可以通过利用numpy来消除循环以提高效率,但我不知道如何操作。

欢迎任何建议!

敬礼

def sliding_std_dev(image_original,radius=5) :
    height, width = image_original.shape
    result = np.zeros_like(image_original) # initialize the output matrix
    hgt = range(radius,height-radius)
    wdt = range(radius,width-radius)
    for i in hgt:
        for j in wdt:
            result[i,j] = np.std(image_original[i-radius:i+radius,j-radius:j+radius])
    return result

你能给出图像和半径的大致尺寸吗? - tom10
5个回答

39

酷技巧:通过给定窗口内平方和与总和,可以计算标准差。

因此,您可以使用数据上的均匀滤波器非常快地计算标准差:

from scipy.ndimage.filters import uniform_filter

def window_stdev(arr, radius):
    c1 = uniform_filter(arr, radius*2, mode='constant', origin=-radius)
    c2 = uniform_filter(arr*arr, radius*2, mode='constant', origin=-radius)
    return ((c2 - c1*c1)**.5)[:-radius*2+1,:-radius*2+1]

这比原始函数快得离谱。对于一个1024x1024的数组和半径为20,旧函数需要34.11秒,而新函数只需要0.11秒,加速了300倍。


这个技巧是如何数学计算的呢?它为每个窗口计算量sqrt(mean(x^2) - mean(x)^2)。我们可以从标准差sqrt(mean((x - mean(x))^2))推导出这个量:

E为期望运算符(基本上就是mean()),X是数据的随机变量。则有:

E[(X - E[X])^2]
= E[X^2 - 2X*E[X] + E[X]^2]
= E[X^2] - E[2X*E[X]] + E[E[X]^2] (由期望运算符的线性性)
= E[X^2] - 2E[X]*E[X] + E[X]^2 (再次由线性性以及E[X]为常数的事实)
= E[X^2] - E[X]^2

这证明了使用这种技术计算出的量在数学上等价于标准差。


这个解决方案看起来非常聪明,但我对它感到不舒服:它似乎在每个邻居上计算平方和平均值的差的平方根。对我来说,这与每个值与平均值之间的差的平方的平均值的平方根是不同的。换句话说,mean(x^2)-mean(x)^2 不等于 mean(x^2-mean(x)^2)。你怎么看? - baptiste pagneux
3
@baptistepagneux:你可以证明它是正确的。快速证明:设x为平均值,X为随机变量,E为期望操作符(与“mean()”相同)。那么, E [(X-x) ^ 2] = E [X ^ 2 - 2Xx + x ^ 2] = E [X ^ 2] - 2E [X] x + x ^ 2 (最后一个等式由于 E 的线性性和 x 是常数)= E [X ^ 2] - x ^ 2(因为根据定义 E [X] = x)=E [X ^ 2] - E [X] ^ 2。证毕。 - nneonneo
我不会使用这个答案 - 它会产生错误的结果。请查看下面我的答案,以获取准确的解决方案。 - Max Jaderberg
1
@MaxJaderberg:能否解释一下为什么或者如何会产生错误的结果?我包含了一个证明来展示这个数学是正确的。 - nneonneo
1
@MaxJaderberg:不需要,你可以自己检查一下。uniform_filter已经执行了归一化操作。在一个随机生成的图像(np.random.random((100,100)))上,我的解决方案产生的结果在所有像素上的误差最大为1.4e-15,与OP代码给出的结果相比。 - nneonneo
显示剩余4条评论

13

在图像处理中,最常用的方法是使用积分图表来完成这种任务。这个想法最初是在1984年介绍在这篇论文中。这个想法是,当您通过在窗口上加法计算数量,并将窗口移动一个像素到右侧时,您不需要添加新窗口中的所有项目,只需要从总数中减去最左边的列,然后添加新的最右边的列即可。因此,如果您从数组的两个维度创建一个累积和数组,则可以通过几个总和和一个减法来获得窗口内的总和。如果您为您的数组及其平方保留了累积和表,则非常容易从这两个表中获取方差。以下是一个实现:

def windowed_sum(a, win):
    table = np.cumsum(np.cumsum(a, axis=0), axis=1)
    win_sum = np.empty(tuple(np.subtract(a.shape, win-1)))
    win_sum[0,0] = table[win-1, win-1]
    win_sum[0, 1:] = table[win-1, win:] - table[win-1, :-win]
    win_sum[1:, 0] = table[win:, win-1] - table[:-win, win-1]
    win_sum[1:, 1:] = (table[win:, win:] + table[:-win, :-win] -
                       table[win:, :-win] - table[:-win, win:])
    return win_sum

def windowed_var(a, win):
    win_a = windowed_sum(a, win)
    win_a2 = windowed_sum(a*a, win)
    return (win_a2 - win_a * win_a / win/ win) / win / win
为了看到这个起作用:

要看到这个起作用:

>>> a = np.arange(25).reshape(5,5)
>>> windowed_var(a, 3)
array([[ 17.33333333,  17.33333333,  17.33333333],
       [ 17.33333333,  17.33333333,  17.33333333],
       [ 17.33333333,  17.33333333,  17.33333333]])
>>> np.var(a[:3, :3])
17.333333333333332
>>> np.var(a[-3:, -3:])
17.333333333333332

这应该比基于卷积的方法运行快几个档次。


这似乎会很快。你应该计时一下。 - nneonneo
2
如果您从上面的代码中删除所有不必要的中间数组,并使用np.addnp.subtract以及out关键字,它的运行速度比uniform_filter慢20-30%左右,对于像窗口大小独立性之类的性能表现相似。因此,如果您想跳过scipy依赖项,则这不是一个坏选择。 - Jaime
我建议您将 a = np.int32(a) 添加为 windowed_var(a, win) 中的第一行,以避免问题。 - phyrox
@phyrox 那会以什么方式有所帮助呢? - Jaime
@Jaime,我在使用一个不同类型的数据数组(我想是int32)时遇到了问题,在最后一次除法中它返回了奇怪的数字。我使用了我写的那行代码解决了这个问题。 - phyrox
显示剩余2条评论

3

首先,有多种方法可以完成此任务。

从速度角度来看,这不是最有效的方式,但使用scipy.ndimage.generic_filter函数可以轻松地在移动窗口上应用任意Python函数。

以下是一个快速示例:

result = scipy.ndimage.generic_filter(data, np.std, size=2*radius)

请注意,边界条件可以通过 mode kwarg 进行控制。
另一种方法是使用一些各种分步技巧来创建数组的视图,这个视图实际上是一个移动窗口,然后沿着最后一个轴应用 np.std。(注意:这是从我之前在这里的一个答案中引用的:https://dev59.com/6W445IYBdhLWcg3wTIZf#4947453)
def strided_sliding_std_dev(data, radius=5):
    windowed = rolling_window(data, (2*radius, 2*radius))
    shape = windowed.shape
    windowed = windowed.reshape(shape[0], shape[1], -1)
    return windowed.std(axis=-1)

def rolling_window(a, window):
    """Takes a numpy array *a* and a sequence of (or single) *window* lengths
    and returns a view of *a* that represents a moving window."""
    if not hasattr(window, '__iter__'):
        return rolling_window_lastaxis(a, window)
    for i, win in enumerate(window):
        if win > 1:
            a = a.swapaxes(i, -1)
            a = rolling_window_lastaxis(a, win)
            a = a.swapaxes(-2, i)
    return a

def rolling_window_lastaxis(a, window):
    """Directly taken from Erik Rigtorp's post to numpy-discussion.
    <http://www.mail-archive.com/numpy-discussion@scipy.org/msg29450.html>"""
    if window < 1:
       raise ValueError, "`window` must be at least 1."
    if window > a.shape[-1]:
       raise ValueError, "`window` is too long."
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

一开始可能有点难理解这里正在发生什么。我不想重新输入解释,所以请看这里:https://dev59.com/i2445IYBdhLWcg3wUYrM#4924433,如果你之前没有看过这种“跨步”技巧。
如果我们比较一个100x100的随机浮点数数组和半径为5的radius,它比原始版本或generic_filter版本快约10倍。然而,这个版本在边界条件上没有灵活性。(它与你目前正在做的完全相同,而generic_filter版本在速度上的代价是给你很多灵活性。)
# Your original function with nested loops
In [21]: %timeit sliding_std_dev(data)
1 loops, best of 3: 237 ms per loop

# Using scipy.ndimage.generic_filter
In [22]: %timeit ndimage_std_dev(data)
1 loops, best of 3: 244 ms per loop

# The "stride-tricks" version above
In [23]: %timeit strided_sliding_std_dev(data)
100 loops, best of 3: 15.4 ms per loop

# Ophion's version that uses `np.take`
In [24]: %timeit new_std_dev(data)
100 loops, best of 3: 19.3 ms per loop

“stride-tricks”版本的缺点是,与“正常”的步幅滚动窗口技巧不同,这个版本确实会复制,并且它比原始数组要大得多。如果您在大型数组上使用此功能,您将遇到内存问题!(顺便说一句,从内存使用和速度方面来看,它基本上等同于@Ophion的答案。只是用了不同的方法来做同样的事情。)


这些方法并不真正等效,radius变量正在执行两个不同的事情。 - Daniel
@Ophion - 哎呀!好发现。我的 radius 实际上是直径。 - Joe Kington
此外,显然边界条件的处理方式是不同的,但在边界之外它们是等效的。我主要发布这篇文章只是为了展示不同的技术,而不是为了完全匹配OP处理边界条件的方式。 - Joe Kington
1
这是strided数组的一个很好的应用,但我主要指出这只有在使用正确的半径并且删除hstackvstack部分时才会快大约5%。 - Daniel
2
在numpy 1.7中,虽然没有得到很好的记录,但你可以给np.std一个轴元组而不是单个轴。因此,您可以获得数组的4D视图,例如形状为(rows-win+1, cols-win+1, win, win),然后在该视图上调用.std(axis=(-1, -2)),并获得窗口化标准差而不制作副本。这将使其等效于@nneonneo提出的均匀滤波器。 - Jaime

1
你可以先获取索引,然后使用np.take来形成新数组:
def new_std_dev(image_original,radius=5):
    cols,rows=image_original.shape

    #First obtain the indices for the top left position
    diameter=np.arange(radius*2)
    x,y=np.meshgrid(diameter,diameter)
    index=np.ravel_multi_index((y,x),(cols,rows)).ravel()

    #Cast this in two dimesions and take the stdev
    index=index+np.arange(rows-radius*2)[:,None]+np.arange(cols-radius*2)[:,None,None]*(rows)
    data=np.std(np.take(image_original,index),-1)

    #Add the zeros back to the output array
    top=np.zeros((radius,rows-radius*2))
    sides=np.zeros((cols,radius))

    data=np.vstack((top,data,top))
    data=np.hstack((sides,data,sides))
    return data

首先生成一些随机数据并检查时间:

a=np.random.rand(50,20)

print np.allclose(new_std_dev(a),sliding_std_dev(a))
True

%timeit sliding_std_dev(a)
100 loops, best of 3: 18 ms per loop

%timeit new_std_dev(a)
1000 loops, best of 3: 472 us per loop

对于更大的数组,只要您拥有足够的内存,它总是更快的:

a=np.random.rand(200,200)

print np.allclose(new_std_dev(a),sliding_std_dev(a))
True

%timeit sliding_std_dev(a)
1 loops, best of 3: 1.58 s per loop

%timeit new_std_dev(a)
10 loops, best of 3: 52.3 ms per loop

原始函数在处理非常小的数组时速度更快,看起来当hgt*wdt > 50时达到了平衡点。需要注意的是您的函数正在处理正方形框架,并将标准偏差放置在右下角索引中,而不是在索引周围进行采样。这是有意为之吗?

感谢@Ophion提供的解决方案,并指出了我原始函数中的错误。实际上,我的意思是:result[i,j] = np.std(image_original [i-radius:i+radius+1,j-radius:j+radius+1]),以使窗口围绕像素居中。为了使用您的算法获得相同的结果,我修改了diameter = np.arange(radius * 2 + 1),它似乎给出了与我修改后的原始函数相同的结果。我花了一段时间才理解您的解决方案,但它确实非常出色和高效。再次感谢。 - baptiste pagneux
我会选择@nneonneo的解决方案。你的解决方案很快,但是在处理大图像时我会遇到内存错误(这对我来说很奇怪,但我不知道底层发生了什么)。无论如何,再次感谢你。 - baptiste pagneux
使用np.take时,它将创建一个三维数组,看起来像这样(k,N,N),其中k是第k个窗口,NxN是窗口组件。正如您所看到的,这会大量复制索引,极大地增加了您的内存占用量,可能为arr_size*N*N。@nneonneo的解决方案是一种优化的例程,不会复制窗口,因此只需要大小为constant*arr_size的数组。从代码的快速浏览中可以看出,constant相当小,看起来像是2-3。总体而言,这是一种更好的做事方式。 - Daniel

0

在尝试使用这里的几个优秀解决方案后,我遇到了包含NaN的数据问题。无论是uniform_filter还是np.cumsum()解决方案都会导致NaN通过输出数组传播而不仅仅是被忽略。

我的解决方案基本上只是将@Jaime答案中的窗口求和函数与卷积交换,这对NaN是稳健的。

def windowed_sum(arr: np.ndarray, radius: int) -> np.ndarray:
    """radius=1 means the pixel itself and the 8 surrounding pixels"""

    kernel = np.ones((radius * 2 + 1, radius * 2 + 1), dtype=int)
    return convolve(arr, kernel, mode="constant", cval=0.0)

def windowed_var(arr: np.ndarray, radius: int) -> np.ndarray:
    """Note: this returns smaller in size than the input array (by radius)"""

    diameter = radius * 2 + 1
    win_sum = windowed_sum(arr, radius)[radius:-radius, radius:-radius]
    win_sum_2 = windowed_sum(arr * arr, radius)[radius:-radius, radius:-radius]
    return (win_sum_2 - win_sum * win_sum / diameter / diameter) / diameter / diameter

def windowed_std(arr: np.ndarray, radius: int) -> np.ndarray:

    output = np.full_like(arr, np.nan, dtype=np.float64)

    var_arr = windowed_var(arr, radius)
    std_arr = np.sqrt(var_arr)
    output[radius:-radius, radius:-radius] = std_arr

    return output

这个执行速度比uniform_filter慢一点,但仍然比许多其他方法(如数组堆叠、迭代等)要快得多。

>>> data = np.random.random((1024, 1024))
>>> %timeit windowed_std(data, 4)
158 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

与执行相同大小数据的uniform_filter相比,该算法仅需约36毫秒

包含一些NaN值:

data = np.arange(100, dtype=np.float64).reshape(10, 10)
data[3:4, 3:4] = np.nan
windowed_std(data, 1)

array([[ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan]])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接