PyTorch中的BatchNorm2d的running_mean/running_var是什么意思?

3

我想知道从nn.BatchNorm2d中调用的running_meanrunning_var是什么。

以下是示例代码,其中bn代表nn.BatchNorm2d

vector = torch.cat([
    torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
    torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
    torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
    torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])

我在Pytorch官方文档和用户社区中都无法理解running_meanrunning_var的含义。

nn.BatchNorm2.running_meannn.BatchNorm2.running_var是什么意思?


2
PyTorch 实现了 BatchNorm 论文,我建议你阅读一下。 - Alexey Larionov
1个回答

6
从原始的Batchnorm论文中:
批量归一化:通过减少内部协变量位移加速深度网络训练,作者为Sergey Ioffe和Christian Szegedy,ICML'2015。
您可以在算法1中看到如何测量给定批次的统计信息。

enter image description here

然而,在不同的批次之间,保留在内存中的是运行统计量,即在每个批次推理时迭代测量的统计量。计算运行均值和运行方差在nn.BatchNorm2d的文档页面中有很好的解释:

enter image description here

默认情况下,momentum系数被设置为0.1,它调节当前批次统计数据对运行统计数据的影响程度:
  • 当系数越接近1时,新的运行统计数据越接近于当前批次的统计数据;

  • 当系数越接近0时,当前批次的统计数据对更新新的运行统计数据的贡献越小。

值得注意的是,Batchnorm2d在空间维度上应用,*此外*,当然还要应用在批次维度上。给定一个形状为(b, c, h, w)的批次,它将计算(b, h, w)上的统计数据。这意味着运行统计数据的形状为(c,),即输入通道中有多少个统计组件(均值和方差)就有多少个统计组件。
以下是一个最简示例:
>>> bn = nn.BatchNorm2d(10)
>>> x = torch.rand(2,10,2,2)

由于在BatchNorm2d默认情况下track_running_stats设置为True,因此在训练模式下进行推断时它将跟踪运行统计信息。

运行平均值和方差分别初始化为零和一。

>>> running_mean, running_var = torch.zeros(x.size(1)),torch.ones(x.size(1))

让我们在训练模式下对 bn 进行推断并检查其运行统计数据:
>>> bn(x)
>>> bn.running_mean, bn.running_var
(tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476, 
         0.0622, 0.0651, 0.0660, 0.0406, 0.0446]),
 tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087, 
         0.9026, 0.9136, 0.9043, 0.9126, 0.9122]))

现在让我们手动计算这些统计数据:
>>> (1-momentum)*running_mean + momentum*xmean
tensor([[0.0650, 0.0432, 0.0373, 0.0534, 0.0476, 
         0.0622, 0.0651, 0.0660, 0.0406, 0.0446]])

>>> (1-momentum)*running_var + momentum*xvar
tensor([[0.9027, 0.9170, 0.9162, 0.9082, 0.9087, 
         0.9026, 0.9136, 0.9043, 0.9126, 0.9122]])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接