PyTorch是否会在nn.Linear中自动应用softmax函数?

16

pytorch中,分类网络模型定义如下:

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.out(x)
        return x

这里使用了softmax吗?在我看来,应该是这样的:

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.relu =  torch.nn.ReLu(inplace=True)
        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer
        self.softmax = torch.nn.Softmax(dim=n_output)
    def forward(self, x):
        x = self.hidden(x)      # activation function for hidden layer
        x = self.relu(x)
        x = self.out(x)
        x = self.softmax(x)
        return x

我知道F.relu(self.relu(x))也应用了ReLU,但是第一个代码块没有应用softmax,对吗?


是的,线性函数不会自动应用softmax。 - unlut
@unlut 谢谢,第二个代码块看起来对吗? - yujuezhao
我认为它是正确的。 - unlut
8
相关提示:如果您正在使用nn.CrossEntropyLoss,那么它会应用log-softmax,然后是nll-loss。您可能希望确保不要两次应用softmax,因为softmax 不是 幂等的。 - jodag
@jodag 谢谢!!! 我在 @dennlinger 的回答下还有更多问题。希望也能听到您的建议! - yujuezhao
1个回答

11

承接@jodag在评论中提到的内容,并稍作拓展来形成一个完整的回答:

不是,PyTorch不会自动应用softmax,你可以在任何时候按照需要使用torch.nn.Softmax()但是,softmax存在数值稳定性问题,我们要尽可能避免这种情况。一种解决方法是使用log-softmax,但这往往比直接计算慢。

特别是当我们使用负对数似然作为损失函数时(在PyTorch中,这是torch.nn.NLLLoss),我们可以利用(log-)softmax+NLLL的导数实际上是数学上非常好和简单的事实,这就是为什么将两者合并成一个单独的函数/元素是有意义的。结果是torch.nn.CrossEntropyLoss。再次注意,这仅直接应用于您网络的最后一层,任何其他计算都不受此影响。

4
如果我理解正确的话,更好的做法是将nn.CrossEntropyLoss用作最后一层nn.Linear()的输出损失函数,而不是直接使用nn.Softmax()。这样正确吗? - yujuezhao
另一个问题随之而来,nn.Softmax()的输出可以被视为某个类别的概率,而nn.Linear()的所有输出之和并不保证等于1。这是否会失去最终输出的意义? - yujuezhao
2
回答你的第一个评论:你并没有真正地用损失函数替换任何层,而是将当前的损失函数(应该是nn.NLLLoss)替换为不同的损失函数,同时删除最后一个nn.Softmax()。我认为你已经有了正确的想法。第二个问题:由于你的损失函数仍然“应用”对数softmax(或者至少你的导数是基于它的),所以解释仍然成立。如果你在任何其他方面使用输出,例如在推理期间,那么当然必须在那种情况下重新应用softmax。 - dennlinger

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接