在Python中实现MATLAB的im2col“滑动”功能

21

问:如何加快这个过程?

以下是我实现的Matlab的im2col'滑动'函数,并增加了每n个列返回的功能。该函数接受一个图像(或任何二维数组),从左到右、从上到下地滑动,获取给定大小的每个重叠子图像,并返回其列数组。

import numpy as np

def im2col_sliding(image, block_size, skip=1):

    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1

    output_vectors = np.zeros((block_size[0] * block_size[1], horz_blocks * vert_blocks))
    itr = 0
    for v_b in xrange(vert_blocks):
        for h_b in xrange(horz_blocks):
            output_vectors[:, itr] = image[v_b: v_b + block_size[0], h_b: h_b + block_size[1]].ravel()
            itr += 1

    return output_vectors[:, ::skip]

例子:

a = np.arange(16).reshape(4, 4)
print a
print im2col_sliding(a, (2, 2))  # return every overlapping 2x2 patch
print im2col_sliding(a, (2, 2), 4)  # return every 4th vector

返回:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[[  0.   1.   2.   4.   5.   6.   8.   9.  10.]
 [  1.   2.   3.   5.   6.   7.   9.  10.  11.]
 [  4.   5.   6.   8.   9.  10.  12.  13.  14.]
 [  5.   6.   7.   9.  10.  11.  13.  14.  15.]]
[[  0.   5.  10.]
 [  1.   6.  11.]
 [  4.   9.  14.]
 [  5.  10.  15.]]

这个性能并不出色,特别是当我调用 im2col_sliding(big_matrix, (8, 8))(62001列)或者 im2col_sliding(big_matrix, (8, 8), 10)(6201列;只保留每10个向量),它们需要相同的时间 [其中big_matrix的大小为256 x 256]。

我正在寻找任何可以加速这个过程的想法。


这个回答对你有帮助吗?还是你特别想加速你自己的代码? - ljetibo
@ljetibo 我看过并尝试了那篇帖子中被接受的答案,但没有将其扩展到我想要的功能。我对任何解决方案都持开放态度。 - Scott
6个回答

32

方法一

我们可以使用一些广播技术来一次性获取所有滑动窗口的索引,从而通过索引实现向量化解决方案。这受到了im2col和col2im的高效实现的启发。

以下是实现代码 -

def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
    # Parameters
    M,N = A.shape
    col_extent = N - BSZ[1] + 1
    row_extent = M - BSZ[0] + 1
    
    # Get Starting block indices
    start_idx = np.arange(BSZ[0])[:,None]*N + np.arange(BSZ[1])
    
    # Get offsetted indices across the height and width of input array
    offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)
    
    # Get all actual indices & index into input array for final output
    return np.take (A,start_idx.ravel()[:,None] + offset_idx.ravel()[::stepsize])

方法二

利用新学到的NumPy数组步幅知识,我们可以创建这样的滑动窗口,从而得到另一种高效的解决方案 -

def im2col_sliding_strided(A, BSZ, stepsize=1):
    # Parameters
    m,n = A.shape
    s0, s1 = A.strides    
    nrows = m-BSZ[0]+1
    ncols = n-BSZ[1]+1
    shp = BSZ[0],BSZ[1],nrows,ncols
    strd = s0,s1,s0,s1
    
    out_view = np.lib.stride_tricks.as_strided(A, shape=shp, strides=strd)
    return out_view.reshape(BSZ[0]*BSZ[1],-1)[:,::stepsize]

方法三

前面提到的跨步方法已经被整合到scikit-image模块中,使得代码更加简洁易读,使用如下:

from skimage.util import view_as_windows as viewW

def im2col_sliding_strided_v2(A, BSZ, stepsize=1):
    return viewW(A, (BSZ[0],BSZ[1])).reshape(-1,BSZ[0]*BSZ[1]).T[:,::stepsize]

样例运行 -

In [106]: a      # Input array
Out[106]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [107]: im2col_sliding_broadcasting(a, (2,3))
Out[107]: 
array([[ 0,  1,  2,  5,  6,  7, 10, 11, 12],
       [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
       [ 2,  3,  4,  7,  8,  9, 12, 13, 14],
       [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
       [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
       [ 7,  8,  9, 12, 13, 14, 17, 18, 19]])

In [108]: im2col_sliding_broadcasting(a, (2,3), stepsize=2)
Out[108]: 
array([[ 0,  2,  6, 10, 12],
       [ 1,  3,  7, 11, 13],
       [ 2,  4,  8, 12, 14],
       [ 5,  7, 11, 15, 17],
       [ 6,  8, 12, 16, 18],
       [ 7,  9, 13, 17, 19]])

运行时测试

In [183]: a = np.random.randint(0,255,(1024,1024))

In [184]: %timeit im2col_sliding(img, (8,8), skip=1)
     ...: %timeit im2col_sliding_broadcasting(img, (8,8), stepsize=1)
     ...: %timeit im2col_sliding_strided(img, (8,8), stepsize=1)
     ...: %timeit im2col_sliding_strided_v2(img, (8,8), stepsize=1)
     ...: 
1 loops, best of 3: 1.29 s per loop
1 loops, best of 3: 226 ms per loop
10 loops, best of 3: 84.5 ms per loop
10 loops, best of 3: 111 ms per loop

In [185]: %timeit im2col_sliding(img, (8,8), skip=4)
     ...: %timeit im2col_sliding_broadcasting(img, (8,8), stepsize=4)
     ...: %timeit im2col_sliding_strided(img, (8,8), stepsize=4)
     ...: %timeit im2col_sliding_strided_v2(img, (8,8), stepsize=4)
     ...: 
1 loops, best of 3: 1.31 s per loop
10 loops, best of 3: 104 ms per loop
10 loops, best of 3: 84.4 ms per loop
10 loops, best of 3: 109 ms per loop

使用步长方法比原始循环版本快了16倍

4
我知道我不应该(违反规则),但这实在让我大吃一惊。花了一堆整齐的纸来弄清楚为什么这个东西有效,里面涉及了许多巧妙地使用了numpy矩阵操作的属性。如果可以的话,我会给你买一杯啤酒\点两次赞……非常感谢,这刚刚让我的晚上变得更美好了。 - ljetibo
2
@ljetibo 谢谢您!您的留言也让我感到很开心!我从 MATLAB 转向 Numpy,虽然不知道如何正确使用 for 循环,但这可能是好是坏,但我非常喜欢它,特别是因为在 numpy 中循环似乎很昂贵。最近,我偶然发现了这个方便的工具 “np.take”,以前在任何 SO 问题上都无法使用它,但在这里却完美契合。感谢您的赞美! :) - Divakar
1
@Scott,你不喜欢这些东西吗? ;) 我知道我喜欢!愿原力与它同在! - Divakar
1
@Divakar在MATLAB标签上非常受欢迎。他的答案是PFM。 - rayryeng
1
@rayryeng,你是从哪里学来这些文本缩写的啊? ;) - Divakar
显示剩余7条评论

9

对于在不同图像通道上进行滑动窗口,我们可以使用Divakar提供的代码的更新版本@“在Python中实现MATLAB的im2col 'sliding'”

import numpy as np
A = np.random.randint(0,9,(2,4,4)) # Sample input array
                    # Sample blocksize (rows x columns)
B = [2,2]
skip=[2,2]
# Parameters 
D,M,N = A.shape
col_extent = N - B[1] + 1
row_extent = M - B[0] + 1

# Get Starting block indices
start_idx = np.arange(B[0])[:,None]*N + np.arange(B[1])

# Generate Depth indeces
didx=M*N*np.arange(D)
start_idx=(didx[:,None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
out = np.take (A,start_idx.ravel()[:,None] + offset_idx[::skip[0],::skip[1]].ravel())

测试 示例运行

A=
[[[6 2 8 5]
[6 4 7 6]
[8 6 5 2]
[3 1 3 7]]

[[6 0 4 3]
[7 6 4 6]
[2 6 7 1]
[7 6 7 7]]]

out=
[6 8 8 5]
[2 5 6 2]
[6 7 3 3]
[4 6 1 7]
[6 4 2 7]
[0 3 6 1]
[7 4 7 7]
[6 6 6 7]

感谢您的整理,这可能会很有用。您可以考虑添加一个%timeit部分,将其与糟糕的循环进行比较。 - Scott
感谢您提供更通用的版本! - Divakar

3
为了进一步提高性能(例如在卷积方面),我们还可以使用基于批处理的实现,这是由M Elyia提供的扩展代码,可以在Python中实现Matlab的im2col“滑动”功能,即:@Implement Matlab's im2col 'sliding' in python
import numpy as np

A = np.arange(3*1*4*4).reshape(3,1,4,4)+1 # 3 Sample input array with 1 channel
B = [2,2] # Sample blocksize (rows x columns)
skip = [2,2]

# Parameters 
batch, D,M,N = A.shape
col_extent = N - B[1] + 1
row_extent = M - B[0] + 1

# Get batch block indices
batch_idx = np.arange(batch)[:, None, None] * D * M * N

# Get Starting block indices
start_idx = np.arange(B[0])[None, :,None]*N + np.arange(B[1])

# Generate Depth indeces
didx=M*N*np.arange(D)
start_idx=(didx[None, :, None]+start_idx.ravel()).reshape((-1,B[0],B[1]))

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[None, :, None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
act_idx = (batch_idx + 
    start_idx.ravel()[None, :, None] + 
    offset_idx[:,::skip[0],::skip[1]].ravel())

out = np.take (A, act_idx)

测试样本运行:

A = 
[[[[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]]


 [[[17 18 19 20]
   [21 22 23 24]
   [25 26 27 28]
   [29 30 31 32]]]


 [[[33 34 35 36]
   [37 38 39 40]
   [41 42 43 44]
   [45 46 47 48]]]] 


out = 
[[[ 1  2  3  9 10 11]
  [ 2  3  4 10 11 12]
  [ 5  6  7 13 14 15]
  [ 6  7  8 14 15 16]]

 [[17 18 19 25 26 27]
  [18 19 20 26 27 28]
  [21 22 23 29 30 31]
  [22 23 24 30 31 32]]

 [[33 34 35 41 42 43]
  [34 35 36 42 43 44]
  [37 38 39 45 46 47]
  [38 39 40 46 47 48]]]

2
我使用Numba JIT编译器实现了快速解决方案。根据块大小和跳过大小,它可以提供从5.67倍3597倍的加速效果。
加速效果指的是Numba算法相对于原始算法的速度增长倍数,例如20倍的加速效果意味着如果原始算法需要200ms,那么快速的Numba算法只需要10ms
我的代码需要通过如下命令安装pip模块:python -m pip install numpy numba timerit matplotlib
接下来是代码、速度增益图和时间测量的控制台输出。

在线试一试!

import numpy as np

# ----- Original Implementation -----

def im2col_sliding(image, block_size, skip = 1):
    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1
    
    if vert_blocks <= 0 or horz_blocks <= 0:
        return np.zeros((block_size[0] * block_size[1], 0), dtype = image.dtype)

    output_vectors = np.zeros((block_size[0] * block_size[1], horz_blocks * vert_blocks), dtype = image.dtype)
    itr = 0
    
    for v_b in range(vert_blocks):
        for h_b in range(horz_blocks):
            output_vectors[:, itr] = image[v_b: v_b + block_size[0], h_b: h_b + block_size[1]].ravel()
            itr += 1

    return output_vectors[:, ::skip]


# ----- Fast Numba Implementation -----
    
import numba

@numba.njit(cache = True)
def im2col_sliding_numba(image, block_size, skip = 1):
    assert skip >= 1
    rows, cols = image.shape
    horz_blocks = cols - block_size[1] + 1
    vert_blocks = rows - block_size[0] + 1
    
    if vert_blocks <= 0 or horz_blocks <= 0:
        return np.zeros((block_size[0] * block_size[1], 0), dtype = image.dtype)
    
    res = np.zeros((block_size[0] * block_size[1], (horz_blocks * vert_blocks + skip - 1) // skip), dtype = image.dtype)
    itr, to_skip, v_b = 0, 0, 0
    
    while True:
        v_b += to_skip // horz_blocks
        if v_b >= vert_blocks:
            break
        h_b_start = to_skip % horz_blocks
        h_cnt = (horz_blocks - h_b_start + skip - 1) // skip
        for i, h_b in zip(range(itr, itr + h_cnt), range(h_b_start, horz_blocks, skip)):
            ii = 0
            for iv in range(v_b, v_b + block_size[0]):
                for ih in range(h_b, h_b + block_size[1]):
                    res[ii, i] = image[iv, ih]
                    ii += 1
        to_skip = skip - (horz_blocks - h_b_start - skip * (h_cnt - 1))
        itr += h_cnt
        v_b += 1
        
    assert itr == res.shape[1]#, (itr, res.shape)

    return res


# ----- Testing -----

from timerit import Timerit
Timerit._default_asciimode = True

side = 256
a = np.random.randint(0, 256, (side, side), dtype = np.uint8)

stats = []

for block_size in [16, 8, 4, 2, 1]:
    for skip_size in [1, 2, 5, 11, 23]:
        print(f'block_size {block_size} skip_size {skip_size}', flush = True)
        for ifn, f in enumerate([im2col_sliding, im2col_sliding_numba]):
            print(f'{f.__name__}: ', end = '', flush = True)
            tim = Timerit(num = 3, verbose = 1)
            for i, t in enumerate(tim):
                if i == 0 and ifn == 1:
                    f(a, (block_size, block_size), skip_size)
                with t:
                    r = f(a, (block_size, block_size), skip_size)
            rt = tim.mean()
            if ifn == 0:
                bt, ba = rt, r
            else:
                assert np.array_equal(ba, r)
                print(f'speedup {round(bt / rt, 2)}x')
                stats.append({
                    'block_size': block_size,
                    'skip_size': skip_size,
                    'speedup': bt / rt,
                })

stats = sorted(stats, key = lambda e: e['speedup'])

import math, matplotlib, matplotlib.pyplot as plt

x = np.arange(len(stats))
y = np.array([e['speedup'] for e in stats])

plt.rcParams['figure.figsize'] = (12.8, 7.2)

for scale in ['linear', 'log']:
    plt.clf()
    plt.xlabel('iteration')
    plt.ylabel(f'speedup_{scale}')
    plt.yscale(scale)
    plt.scatter(x, y, marker = '.')
    for i in range(x.size):
        plt.annotate(
            (f"b{str(stats[i]['block_size']).zfill(2)}s{str(stats[i]['skip_size']).zfill(2)}\n" +
             f"x{round(stats[i]['speedup'], 2 if stats[i]['speedup'] < 100 else 1 if stats[i]['speedup'] < 1000 else None)}"),
            (x[i], y[i]), fontsize = 'small',
        )
    plt.subplots_adjust(left = 0.055, right = 0.99, bottom = 0.08, top = 0.99)
    plt.xlim(left = -0.1)
    if scale == 'linear':
        ymin, ymax = np.amin(y), np.amax(y)
        plt.ylim((ymin - (ymax - ymin) * 0.02, ymax + (ymax - ymin) * 0.05))
        plt.yticks([ymin] + [e for e in plt.yticks()[0] if ymin + 0.01 < e < ymax - 0.01] + [ymax])
        #plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
    plt.savefig(f'im2col_numba_{scale}.png', dpi = 150)
    plt.show()

下一个图表的横轴为迭代次数,纵轴为加速比。第一个图表的纵轴是线性的,第二个图表的纵轴是对数的。每个点还有标签 bXXsYYxZZ ,其中 XX 表示块大小, YY 表示跳过(步长)大小, ZZ 表示加速比。
线性图表:

linear plot

对数图表:

log plot

控制台输出:
block_size 16 skip_size 1
im2col_sliding: Timed best=549.069 ms, mean=549.069 +- 0.0 ms
im2col_sliding_numba: Timed best=96.841 ms, mean=96.841 +- 0.0 ms
speedup 5.67x
block_size 16 skip_size 2
im2col_sliding: Timed best=559.396 ms, mean=559.396 +- 0.0 ms
im2col_sliding_numba: Timed best=71.132 ms, mean=71.132 +- 0.0 ms
speedup 7.86x
block_size 16 skip_size 5
im2col_sliding: Timed best=561.030 ms, mean=561.030 +- 0.0 ms
im2col_sliding_numba: Timed best=15.000 ms, mean=15.000 +- 0.0 ms
speedup 37.4x
block_size 16 skip_size 11
im2col_sliding: Timed best=559.045 ms, mean=559.045 +- 0.0 ms
im2col_sliding_numba: Timed best=6.719 ms, mean=6.719 +- 0.0 ms
speedup 83.21x
block_size 16 skip_size 23
im2col_sliding: Timed best=562.462 ms, mean=562.462 +- 0.0 ms
im2col_sliding_numba: Timed best=2.514 ms, mean=2.514 +- 0.0 ms
speedup 223.72x
block_size 8 skip_size 1
im2col_sliding: Timed best=373.790 ms, mean=373.790 +- 0.0 ms
im2col_sliding_numba: Timed best=17.441 ms, mean=17.441 +- 0.0 ms
speedup 21.43x
block_size 8 skip_size 2
im2col_sliding: Timed best=375.858 ms, mean=375.858 +- 0.0 ms
im2col_sliding_numba: Timed best=8.791 ms, mean=8.791 +- 0.0 ms
speedup 42.75x
block_size 8 skip_size 5
im2col_sliding: Timed best=376.767 ms, mean=376.767 +- 0.0 ms
im2col_sliding_numba: Timed best=3.115 ms, mean=3.115 +- 0.0 ms
speedup 120.94x
block_size 8 skip_size 11
im2col_sliding: Timed best=378.284 ms, mean=378.284 +- 0.0 ms
im2col_sliding_numba: Timed best=1.406 ms, mean=1.406 +- 0.0 ms
speedup 268.97x
block_size 8 skip_size 23
im2col_sliding: Timed best=376.268 ms, mean=376.268 +- 0.0 ms
im2col_sliding_numba: Timed best=661.404 us, mean=661.404 +- 0.0 us
speedup 568.89x
block_size 4 skip_size 1
im2col_sliding: Timed best=378.813 ms, mean=378.813 +- 0.0 ms
im2col_sliding_numba: Timed best=4.950 ms, mean=4.950 +- 0.0 ms
speedup 76.54x
block_size 4 skip_size 2
im2col_sliding: Timed best=377.620 ms, mean=377.620 +- 0.0 ms
im2col_sliding_numba: Timed best=2.119 ms, mean=2.119 +- 0.0 ms
speedup 178.24x
block_size 4 skip_size 5
im2col_sliding: Timed best=374.792 ms, mean=374.792 +- 0.0 ms
im2col_sliding_numba: Timed best=854.986 us, mean=854.986 +- 0.0 us
speedup 438.36x
block_size 4 skip_size 11
im2col_sliding: Timed best=373.296 ms, mean=373.296 +- 0.0 ms
im2col_sliding_numba: Timed best=415.028 us, mean=415.028 +- 0.0 us
speedup 899.45x
block_size 4 skip_size 23
im2col_sliding: Timed best=374.075 ms, mean=374.075 +- 0.0 ms
im2col_sliding_numba: Timed best=219.491 us, mean=219.491 +- 0.0 us
speedup 1704.28x
block_size 2 skip_size 1
im2col_sliding: Timed best=377.698 ms, mean=377.698 +- 0.0 ms
im2col_sliding_numba: Timed best=1.477 ms, mean=1.477 +- 0.0 ms
speedup 255.67x
block_size 2 skip_size 2
im2col_sliding: Timed best=378.155 ms, mean=378.155 +- 0.0 ms
im2col_sliding_numba: Timed best=841.298 us, mean=841.298 +- 0.0 us
speedup 449.49x
block_size 2 skip_size 5
im2col_sliding: Timed best=376.381 ms, mean=376.381 +- 0.0 ms
im2col_sliding_numba: Timed best=392.541 us, mean=392.541 +- 0.0 us
speedup 958.83x
block_size 2 skip_size 11
im2col_sliding: Timed best=374.720 ms, mean=374.720 +- 0.0 ms
im2col_sliding_numba: Timed best=193.093 us, mean=193.093 +- 0.0 us
speedup 1940.62x
block_size 2 skip_size 23
im2col_sliding: Timed best=378.092 ms, mean=378.092 +- 0.0 ms
im2col_sliding_numba: Timed best=105.101 us, mean=105.101 +- 0.0 us
speedup 3597.42x
block_size 1 skip_size 1
im2col_sliding: Timed best=203.410 ms, mean=203.410 +- 0.0 ms
im2col_sliding_numba: Timed best=686.335 us, mean=686.335 +- 0.0 us
speedup 296.37x
block_size 1 skip_size 2
im2col_sliding: Timed best=202.865 ms, mean=202.865 +- 0.0 ms
im2col_sliding_numba: Timed best=361.255 us, mean=361.255 +- 0.0 us
speedup 561.56x
block_size 1 skip_size 5
im2col_sliding: Timed best=200.929 ms, mean=200.929 +- 0.0 ms
im2col_sliding_numba: Timed best=164.740 us, mean=164.740 +- 0.0 us
speedup 1219.68x
block_size 1 skip_size 11
im2col_sliding: Timed best=202.163 ms, mean=202.163 +- 0.0 ms
im2col_sliding_numba: Timed best=96.791 us, mean=96.791 +- 0.0 us
speedup 2088.65x
block_size 1 skip_size 23
im2col_sliding: Timed best=202.492 ms, mean=202.492 +- 0.0 ms
im2col_sliding_numba: Timed best=64.527 us, mean=64.527 +- 0.0 us
speedup 3138.1x

我之前不知道 Numba。谢谢,这是一个很棒的解决方案。 - Scott
@Scott Numba是一个特殊的Python模块,它允许将任何相当简单的Python函数转换为C++优化代码并将其编译为机器代码。如果函数有很多循环和与numpy数组的交互,则所有这些操作都会转换为相应的C++操作。这样,numba平均可以将任何Python代码优化以使其运行速度快50倍至300倍。不需要特殊知识,只需将"@numba.njit"装饰器添加到函数中即可完成! - Arty

0

你也可以对M Eliya的答案进行进一步优化(虽然不是很显著)

在最后阶段“应用”skip之前,你可以在生成偏移数组时应用它,这样就不需要:

# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)

# Get all actual indices & index into input array for final output
out = np.take (A,start_idx.ravel()[:,None] + offset_idx[::skip[0],::skip[1]].ravel())

你可以使用numpy的arange函数的step参数来添加跳过:
# Get offsetted indices across the height and width of input array and add skips
offset_idx = np.arange(row_extent, step=skip[0])[:, None] * N + np.arange(col_extent, step=skip[1])

然后只需添加偏移数组,无需使用[::]索引

# Get all actual indices & index into input array for final output

out = np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel())

在小的跳过值上,它几乎不会节省任何时间:

In[25]:
A = np.random.randint(0,9,(3, 1024, 1024))
B = [2, 2]
skip = [2, 2]

In[26]: %timeit im2col(A, B, skip)
10 loops, best of 3: 19.7 ms per loop

In[27]: %timeit im2col_optimized(A, B, skip)
100 loops, best of 3: 17.5 ms per loop

然而,使用更大的跳过值可以节省更多时间:

In[28]: skip = [10, 10]
In[29]: %timeit im2col(A, B, skip)
100 loops, best of 3: 3.85 ms per loop

In[30]: %timeit im2col_optimized(A, B, skip)
1000 loops, best of 3: 1.02 ms per loop

A = np.random.randint(0,9,(3, 2000, 2000))
B = [10, 10]
skip = [10, 10]

In[43]: %timeit im2col(A, B, skip)
10 loops, best of 3: 87.8 ms per loop

In[44]: %timeit im2col_optimized(A, B, skip)
10 loops, best of 3: 76.3 ms per loop

0

我认为你做不得更好了。显然,你必须运行一个大小为

cols - block_size[1] * rows - block_size[0]

的循环。但是在你的例子中,你使用的是3x3的补丁而不是2x2的。


在我的例子中,列的长度为4(2x2)。我是不是出现了错误或打错了什么? - Scott
不,只是因为你在2、2上调用了函数,它会给出3、3的矩阵,这让人感到困惑。就是这样。 - InfiniteLooper

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接