在nn.Module
中是否调用forward()
? 我认为当我们调用模型时,forward
方法正在被使用。
那我们为什么需要指定train()呢?
model.train()
告诉模型你正在训练它。这有助于通知Dropout和BatchNorm等层在训练和评估期间表现不同的设计。例如,在训练模式下,BatchNorm会在每个新批次上更新移动平均值;而在评估模式下,这些更新被冻结。
更多细节:
model.train()
将模式设置为训练模式
(请参见源代码)。您可以调用model.eval()
或model.train(mode=False)
来告诉模型您正在测试。
虽然人们很容易期望train
函数来训练模型,但它实际上只是设置模式。
mdl.is_eval()
? - Charlie Parkermodel.training
标志。在“eval”模式下时,它为“False”。 - Umang Guptamodel.eval()
如何影响反向传递? - mrgloom这是nn.Module.train()
的代码:
def train(self, mode=True):
r"""Sets the module in training mode."""
self.training = mode
for module in self.children():
module.train(mode)
return self
这里是nn.Module.eval()
的代码:
def eval(self):
r"""Sets the module in evaluation mode."""
return self.train(False)
self.training
标志被设置为True
,即模块默认处于训练模式。当self.training
为False
时,模块处于相反的状态,即评估模式。Dropout
和BatchNorm
关心该标志。self.training
标志的层吗? - Melikemodel.eval()
如何影响反向传递? - mrgloommodel.eval()
只是一个开关,用于不使用dropout和batch norms。我有一篇很好的PyTorch培训介绍,您可以在其中检查前向和后向传递,以及深入了解PyTorch AD,您可以自信地理解PyTorch AD的详细信息。 - prostimodel.train() |
model.eval() |
---|---|
将模型设置为训练模式,即 • BatchNorm 层使用每批次的统计信息• Dropout 层激活等等 |
将模型设置为评估(推断)模式,即 • BatchNorm 层使用运行时的统计信息• Dropout 层停用等 |
相当于model.train(False) 。 |
注意:这两个函数调用都不会运行前向/反向传递。它们告诉模型在运行时应该如何操作。
这很重要,因为某些模块(层)(例如Dropout
、BatchNorm
)在训练和推断期间设计的行为不同,因此如果在错误的模式下运行模型,则会产生意外的结果。
让模型知道你的意图有两种方法,即你想训练模型还是使用模型进行评估。
在 model.train()
的情况下,模型知道它必须学习各层,而当我们使用 model.eval()
时,它表示模型不需要学习任何新内容,模型用于测试。
model.eval()
也是必要的,因为在 pytorch 中,如果我们使用批标准化,并且在测试期间只想传递一张图片,如果未指定 model.eval()
,pytorch 会抛出一个错误。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GraphNet(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GraphNet, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.dropout(x, training=self.training) #Look here
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
dropout
的功能在不同的操作模式下会有所不同。正如你所看到的,它只在self.training==True
时起作用。因此,当你输入model.train()
时,模型的前向函数将执行dropout,否则不会(比如当model.eval()
或model.train(mode=False)
时)。目前的官方文档中如下所述:
这仅对某些模块产生影响。有关特定模块在训练/评估模式下的行为是否受到影响的详细信息,请参阅其文档,例如Dropout、BatchNorm等。
self.train(False)
简单地递归地为所有模块更改self.training
,实际上这就是self.train
所做的,递归地将标志更改为true。请参见代码:https://github.com/pytorch/pytorch/blob/6e1a5b1196aa0277a2113a4bca75b6e0f2b4c0c8/torch/nn/modules/module.py#L1432 - Charlie Parker