如何在pytorch中对权重添加L1或L2正则化

4
在TensorFlow中,我们可以在序列模型中添加L1或L2正则化。我找不到PyTorch中等效的方法。在定义网络时,我们如何为权重添加正则化到PyTorch中:
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        """ How to add a L1 regularization after a certain hidden layer?? """
        """ OR How to add a L1 regularization after a certain hidden layer?? """
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        return x

net = Net(n_feature=1, n_hidden=10, n_output=1)     # define the network
# print(net)  # net architecture
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()  # this is for regression mean squared loss

感谢您的帮助。我需要对权重进行正则化,使用L1或L2中的一种,以获得最佳结果。 - zezo
1个回答

5
一般情况下,PyTorch中的L2正则化是通过优化器的weight_decay参数处理的(您也可以为不同层分配不同的参数)。然而,这种机制没有提供L1正则化的支持,除非扩展现有的优化器或编写自定义优化器。
根据TensorFlow文档,他们使用reduce_sum(abs(x))惩罚项实现L1正则化,使用reduce_sum(square(x))惩罚项实现L2正则化。最简单的方法可能就是将这些惩罚项直接添加到用于计算梯度的损失函数中进行训练。
# set l1_weight and l2_weight to non-zero values to enable penalties

# inside the training loop (given input x and target y)
...
pred = net(x)
loss = loss_func(pred, y)

# compute penalty only for net.hidden parameters
l1_penalty = l1_weight * sum([p.abs().sum() for p in net.hidden.parameters()])
l2_penalty = l2_weight * sum([(p**2).sum() for p in net.hidden.parameters()])
loss_with_penalty = loss + l1_penalty + l2_penalty

optimizer.zero_grad()
loss_with_penalty.backward()
optimizer.step()

# The pre-penalty loss is the one we ultimately care about
print('loss:', loss.item())

1
在Python中,正则化是什么意思?不理解网上的定义。 - SmilingMouse
3
我能想到的最直观的解释是,它是机器学习中的一种技术,如果模型参数“过于偏离”某些预定义状态,就会对模型进行惩罚。从某种意义上说,这限制了模型的复杂性,并有助于避免过度拟合等问题。 - jodag
1
.backward()在计算导数时是否考虑惩罚项? - Atif Ali

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接