model.eval() 和 model.train() 在 PyTorch 中影响哪些模块?

9
model.eval() 方法会修改某些模块(层),这些模块在训练和推理时具有不同的行为。其中一些模块(层)在文档中列出了部分示例,例如:DropoutBatchNorm等等。请参考特定模块的文档以了解它们在训练/推理模式下的行为是否会受到影响。
是否有哪些模块会受到影响的详尽列表?

我认为...就是这样了吧?我不记得还有其他标准层会改变其行为,但也许我错了,如果这个列表存在的话,我会很快被纠正:)当然,我考虑了所有继承自BatchNorm的层。 - Proko
2个回答

11

除了@iacob提供的信息外:

基类 模块 标准
RNNBase RNN
LSTM
GRU
dropout > 0(默认值:0
Transformer层 Transformer
TransformerEncoder
TransformerDecoder
dropout > 0Transformer 默认值:0.1
Lazy变体 LazyBatchNorm
目前是夜版
合并 PR
track_running_stats=True

2
GroupNorm和LayerNorm不跟踪运行统计信息,并且不受model.eval()的影响。 - Trisoloriansunscreen

7

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接