如何在pytorch中限制参数的范围?

3

在PyTorch中,通常情况下,模型的参数没有严格的限制,但如果我想让它们保持在[0,1]范围内呢?有没有一种方式可以阻止参数更新到该范围之外?

1个回答

8
某些生成对抗网络(其中一些需要将鉴别器的参数限制在某个范围内)中使用的技巧是,在每次梯度更新后将值夹紧。例如:
model = YourPyTorchModule()

for _ in range(epochs):
    loss = ...
    optimizer.step()
    for p in model.parameters():
        p.data.clamp_(-1.0, 1.0)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接