Pytorch的doc方法用于注册缓存区,该缓存区不应视为模型参数。例如,BatchNorm的running_mean
不是参数,而是持久状态的一部分。
正如您已经观察到的那样,在训练过程中使用SGD学习和更新模型参数。
但是,有时存在其他数量,它们是模型的“状态”并且应该:
- 作为state_dict
的一部分保存。
- 与模型的其余参数一起移动到cuda()
或cpu()
。
- 与模型的其余参数一起转换为float
/half
/double
。
将这些“参数”注册为模型的buffer
允许pytorch跟踪它们并像常规参数一样保存它们,但防止pytorch使用SGD机制更新它们。
在_BatchNorm
模块中可以找到缓冲区的一个示例,其中running_mean
、running_var
和num_batches_tracked
被注册为缓冲区,并通过累积转发到该层的数据的统计信息进行更新。这与使用常规SGD优化学习数据的仿射变换的weight
和bias
参数形成对比。
在创建一个模块 (nn.Module
) 的参数和缓冲区时,你需要使用 register_parameter()
注册一个新的命名张量作为参数。例如,你已经有了线性层 nn.Linear
,其中包括了 weight
和 bias
参数。但是如果你需要一个新的参数,则需要注册一个新的命名参数。
当你注册一个新的参数时,它将出现在 module.parameters()
迭代器中,但当你注册一个缓冲区时,则不会出现在其中。
两者的区别:
缓冲区 是一种被命名的张量,不像参数那样在每一步更新梯度。 对于缓冲区,你可以自定义它的逻辑(完全取决于你)。
好处是,当你保存模型时,所有的参数和缓冲区都会被保存下来,当你把模型移动到或从 CUDA 上移动时,参数和缓冲区也会随之移动。
register_buffer(name, tensor, persistent=False)
注册的缓冲区的用例是什么,即它们甚至不是state_dict
的一部分? - bluenote10nn.Module
中有一个常量张量(例如位置嵌入等),您希望确保每当您调用.to()
时,此张量都会移动到适当的设备并转换为正确的dtype
... - Shai