PyTorch中register_parameter和register_buffer有什么区别?

31

在训练过程中,模块的参数会发生变化,也就是说,它们是神经网络训练过程中学习到的内容。那么,缓冲区是什么?

它是否也是在神经网络训练时学习到的呢?

2个回答

60

Pytorch的doc方法用于注册缓存区,该缓存区不应视为模型参数。例如,BatchNorm的running_mean不是参数,而是持久状态的一部分。

正如您已经观察到的那样,在训练过程中使用SGD学习和更新模型参数
但是,有时存在其他数量,它们是模型的“状态”并且应该:
- 作为state_dict的一部分保存。
- 与模型的其余参数一起移动到cuda()cpu()
- 与模型的其余参数一起转换为float/half/double
将这些“参数”注册为模型的buffer允许pytorch跟踪它们并像常规参数一样保存它们,但防止pytorch使用SGD机制更新它们。

_BatchNorm模块中可以找到缓冲区的一个示例,其中running_meanrunning_varnum_batches_tracked被注册为缓冲区,并通过累积转发到该层的数据的统计信息进行更新。这与使用常规SGD优化学习数据的仿射变换的weightbias参数形成对比。


2
这引出了一个问题,即通过register_buffer(name, tensor, persistent=False)注册的缓冲区的用例是什么,即它们甚至不是state_dict的一部分? - bluenote10
2
@bluenote10,这很有道理。假设您的nn.Module中有一个常量张量(例如位置嵌入等),您希望确保每当您调用.to()时,此张量都会移动到适当的设备并转换为正确的dtype... - Shai

27

在创建一个模块 (nn.Module) 的参数和缓冲区时,你需要使用 register_parameter() 注册一个新的命名张量作为参数。例如,你已经有了线性层 nn.Linear,其中包括了 weightbias 参数。但是如果你需要一个新的参数,则需要注册一个新的命名参数。

当你注册一个新的参数时,它将出现在 module.parameters() 迭代器中,但当你注册一个缓冲区时,则不会出现在其中。

两者的区别:

缓冲区 是一种被命名的张量,不像参数那样在每一步更新梯度。 对于缓冲区,你可以自定义它的逻辑(完全取决于你)。

好处是,当你保存模型时,所有的参数和缓冲区都会被保存下来,当你把模型移动到或从 CUDA 上移动时,参数和缓冲区也会随之移动。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接