使用PyTorch验证卷积定理

4

这个定理基本上可以表述为:

F(f*g) = F(f)xF(g)

我知道这个定理,但是我无法使用PyTorch重现结果。

以下是可重现的代码:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)

这里是 print(FxG - F_fg) 的结果

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])

您可以看到,差异并不总是为0。

有人能告诉我为什么,并且如何正确地做这件事吗?

谢谢。


2
在CNN文献中所谓的“卷积”实际上在信号处理术语中被称为相关滤波。基本上,在CNN中,在滑动和乘法之前,核心不会被翻转。尝试使用F_g = torch.rfft(g_new.flip(0).flip(1), ...,这应该可以让您更接近结果。由于DFT假定信号是周期性的(对于傅里叶变换是离散的必要条件),因此可能还存在一些填充差异。我稍后会验证这一点。 - jodag
1个回答

9

所以我更仔细地研究了你目前的工作。我在你的代码中发现了三个错误源。我会在这里逐一进行说明。

1. 复杂运算

当前PyTorch不支持复数的乘法(据我所知)。FFT操作只是返回一个带有实部和虚部的张量。因此,我们需要显式编写复数乘法,而不使用torch.mul*运算符。

(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)

2. 卷积的定义

CNN文献中使用的“卷积”的定义与讨论卷积定理时使用的定义实际上是不同的。我不会详细介绍,但理论定义在滑动和乘法之前反转核。相反,pytorch、tensorflow、caffe等等中的卷积操作并不进行反转。

为了解决这个问题,我们可以在应用FFT之前简单地翻转g(水平和垂直方向都要翻转)。

3. 锚点位置

使用卷积定理时,假定锚点位置为填充的g的左上角。同样,我不会详细介绍这个问题,但这是数学计算的原理。


第二和第三点可能更容易通过一个例子来理解。假设你使用以下的g

[1 2 3]
[4 5 6]
[7 8 9]

代替g_new的是
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]

实际上应该是

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]

我们需要将卷积核水平和垂直翻转,然后进行圆形移位,以使卷积核的中心位于左上角。


我最终重写了大部分您的代码并稍微泛化了一下。最复杂的操作是正确定义g_new。我决定使用meshgrid和模算术来同时翻转和移位索引。如果您对此有任何疑问,请留言,我会尽力澄清。

import torch
import torch.nn.functional as F

def conv2d_pyt(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    f_new = f.unsqueeze(0).unsqueeze(0)
    g_new = g.unsqueeze(0).unsqueeze(0)

    pad_y = (g.size(0) - 1) // 2
    pad_x = (g.size(1) - 1) // 2

    fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
    return fcg[0, 0, :, :]

def conv2d_fft(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    # in general not necessary that inputs are odd shaped but makes life easier
    assert f.size(0) % 2 == 1
    assert f.size(1) % 2 == 1
    assert g.size(0) % 2 == 1
    assert g.size(1) % 2 == 1

    size_y = f.size(0) + g.size(0) - 1
    size_x = f.size(1) + g.size(1) - 1

    f_new = torch.zeros((size_y, size_x))
    g_new = torch.zeros((size_y, size_x))

    # copy f to center
    f_pad_y = (f_new.size(0) - f.size(0)) // 2
    f_pad_x = (f_new.size(1) - f.size(1)) // 2
    f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f

    # anchor of g is 0,0 (flip g and wrap circular)
    g_center_y = g.size(0) // 2
    g_center_x = g.size(1) // 2
    g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
    g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
    g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
    g_new[g_new_y, g_new_x] = g[g_y, g_x]

    # take fft of both f and g
    F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
    F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)

    # complex multiply
    FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
    FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
    FxG = torch.stack([FxG_real, FxG_imag], dim=2)

    # inverse fft
    fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)

    # crop center before returning
    return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]


# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)

fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)

avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()

print('Average difference:', avg_diff)

这给了我

Average difference: 4.6866085767760524e-07

这非常接近零。我们没有得到精确的零是因为浮点误差。


有哪些好的资源可以学习更多关于这个问题,特别是关于循环移位和零填充卷积核的? - Kiran
1
@Kiran 1) 如果信号在时间上是周期性的,则其频率离散;如果信号在频率上是离散的,则其时间周期性。因此,将从离散时间到离散频率的DFT假定信号在时间上是周期性的。这就解释了为什么所有的移位都是循环的。2)锚点位置源于将DFT的输入解释为以t = 0开始的信号单周期的惯例。你可以在任何DSP书籍中学习这一点,我最喜爱的是Oppenheim的“离散时间信号处理”。 - jodag

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接