Python:利用6维数组实现的im2col?

6
我正在阅读一本深度学习书籍(第7章,CNN),其中包含了实现 im2col 的代码(链接)。它的目的是将四维数组转换为二维数组。但我不知道为什么该实现中有一个六维数组。我非常想知道作者使用该算法的思想背后的原理。
我已经尝试搜索很多关于 im2col 实现的论文,但没有一个像这样使用高维数组。我找到的有用于可视化 im2col 过程的材料是此论文的图片(链接)
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Parameters
    ----------
    input_data : (batch size, channel, height, width), or (N,C,H,W) at below
    filter_h : kernel height
    filter_w : kernel width
    stride : size of stride
    pad : size of padding
    Returns
    -------
    col : two dimensional array
    """
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col
2个回答

11

让我们试着想象一下im2col的作用。它以颜色图像堆栈作为输入,该堆栈具有图像ID、颜色通道、垂直位置和水平位置的维度。为了简单起见,假设我们只有一张图片:

enter image description here

首先它会进行填充:

enter image description here

接下来,它将其分成窗口。窗口的大小由filter_h/w控制,重叠由strides控制。

enter image description here

这是六个维度的来源:图片ID(在示例中缺失,因为我们只有一张图片),网格高度/宽度,颜色通道,窗口高度/宽度。

enter image description here

目前的算法有些笨拙,它将输出按错误的维度顺序组装起来,然后需要使用transpose进行更正。

最好一开始就做对:

def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
    col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
    for y in range(out_h):
        for x in range(out_w):
            col[:, y, x] = img[
                ..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

顺便提一下:我们可以使用stride_tricks来做得更好,避免嵌套的for循环:

def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    NN, CC, HH, WW = img.strides
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
    col = np.lib.stride_tricks.as_strided(img, (N, out_h, out_w, C, filter_h, filter_w), (NN, stride * HH, stride * WW, CC, HH, WW)).astype(float)
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

算法最后要做的事情是重新塑形,合并前三个维度(在我们的示例中仅为两个,因为只有一张图像)。红色箭头显示了如何将单独的窗口对齐到第一个新维度中:

enter image description here

最后三个维度,颜色通道、窗口中的y坐标和窗口中的x坐标合并为第二个输出维度。每个像素按照黄色箭头所示排列:

enter image description here


非常感谢!但为什么不把它命名为im2row呢? - undefined

3
看起来这个函数只是将每个N张图片中的每个C颜色通道重新排列成一个(out_h x out_w)大小的重叠图像块网格,大小为(filter_h x filter_w),然后将其展平成一个二维数组,其中每一行都是图像块中像素的向量。
在转置和重塑之前,6-D col的尺寸如下: [sample, channel, y_position_within_patch, x_position_within_patch, y_patch_index, x_patch_index]
例如,col[n,c,:,:,i,j]将是一个2-D图像块(在图像块网格中从上到下的第i个块,从左到右的第j个块)。
在转置和重塑之后,col[n*c*i*j,:]将指代相同的图像块,但所有像素都被展平成一个向量。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接