model.train()在PyTorch中的作用是将模型设置为训练模式。

214

nn.Module中是否调用forward()? 我认为当我们调用模型时,forward方法正在被使用。 那我们为什么需要指定train()呢?


3
现在PyTorch里面有一份文档:https://pytorch.org/docs/stable/generated/torch.nn.Module.html,你可以查看这个文档,我认为它描述得非常清楚。其他的库或框架可能会缺乏文档支持,但是我认为PyTorch的官方文档非常好。 - Konstantin Burlachenko
也许"configure_training"或者"set_training_mode"会是这个函数更好的命名。 - Rexcirus
1
它通过执行self.train(False)简单地递归地为所有模块更改self.training,实际上这就是self.train所做的,递归地将标志更改为true。请参见代码:https://github.com/pytorch/pytorch/blob/6e1a5b1196aa0277a2113a4bca75b6e0f2b4c0c8/torch/nn/modules/module.py#L1432 - Charlie Parker
6个回答

287

model.train() 告诉模型你正在训练它。这有助于通知Dropout和BatchNorm等层在训练和评估期间表现不同的设计。例如,在训练模式下,BatchNorm会在每个新批次上更新移动平均值;而在评估模式下,这些更新被冻结。

更多细节: model.train()将模式设置为训练模式 (请参见源代码)。您可以调用model.eval()model.train(mode=False)来告诉模型您正在测试。 虽然人们很容易期望train函数来训练模型,但它实际上只是设置模式。


4
有没有一个标志可以检测模型是否处于评估模式?例如,mdl.is_eval() - Charlie Parker
7
使用model.training标志。在“eval”模式下时,它为“False”。 - Umang Gupta
@UmangGupta - 确实,默认情况下model.training为True,但是如果您查看链接,他们的训练循环在train步骤之后有一个eval步骤 - 在那里他们调用model.eval()。这将使model.training变为False,他们不会重置它。我理解这非常反直觉 - 我也感到困惑。仍在努力理解为什么会发生这种情况。 - Indrajit
2
@UmangGupta- 实际上我刚刚弄清楚了发生了什么。我的model.train()实际上影响了batchnorm和dropout层,从而影响了模型的性能。 - Indrajit
我想知道 model.eval() 如何影响反向传递? - mrgloom
显示剩余3条评论

93

这是nn.Module.train()的代码:

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这里是nn.Module.eval()的代码:

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

默认情况下,self.training标志被设置为True,即模块默认处于训练模式。当self.trainingFalse时,模块处于相反的状态,即评估模式。
在最常用的层中,只有DropoutBatchNorm关心该标志。

现在还有其他支持 self.training 标志的层吗? - Melike
我想知道 model.eval() 如何影响反向传递? - mrgloom
model.eval()只是一个开关,用于不使用dropout和batch norms。我有一篇很好的PyTorch培训介绍,您可以在其中检查前向和后向传递,以及深入了解PyTorch AD,您可以自信地理解PyTorch AD的详细信息。 - prosti
1
@Melike https://dev59.com/tb_qa4cB1Zd3GeqPPctG - iacob

38
model.train() model.eval()
将模型设置为训练模式,即

BatchNorm层使用每批次的统计信息
Dropout层激活等
将模型设置为评估(推断)模式,即

BatchNorm层使用运行时的统计信息
Dropout层停用等
相当于model.train(False)

注意:这两个函数调用都不会运行前向/反向传递。它们告诉模型在运行时应该如何操作。

这很重要,因为某些模块(层)(例如DropoutBatchNorm)在训练和推断期间设计的行为不同,因此如果在错误的模式下运行模型,则会产生意外的结果。


16

让模型知道你的意图有两种方法,即你想训练模型还是使用模型进行评估。 在 model.train() 的情况下,模型知道它必须学习各层,而当我们使用 model.eval() 时,它表示模型不需要学习任何新内容,模型用于测试。 model.eval() 也是必要的,因为在 pytorch 中,如果我们使用批标准化,并且在测试期间只想传递一张图片,如果未指定 model.eval(),pytorch 会抛出一个错误。


1
考虑以下模型。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphNet, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.dropout(x, training=self.training) #Look here
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

在这里,dropout的功能在不同的操作模式下会有所不同。正如你所看到的,它只在self.training==True时起作用。因此,当你输入model.train()时,模型的前向函数将执行dropout,否则不会(比如当model.eval()model.train(mode=False)时)。

1

目前的官方文档中如下所述:

这仅对某些模块产生影响。有关特定模块在训练/评估模式下的行为是否受到影响的详细信息,请参阅其文档,例如Dropout、BatchNorm等。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接