如何在pytorch中进行梯度裁剪?

105

如何在PyTorch中正确执行梯度裁剪?

我的模型存在梯度爆炸的问题。


https://discuss.pytorch.org/t/proper-way-to-do-gradient-clipping/191 - p13rr0m
8
谢谢@pierrom。我自己找到了那篇帖子。不过我想在这里问一下,这样可以让之后的人节省时间,不必读完所有的讨论(我自己也还没有读完),只需像在stackoverflow上那样得到一个快速的回答。去论坛寻找答案让我想起了1990年的事情。如果没有其他人在我之前发布答案,我会在找到答案后发出来的。 - Gulzar
4个回答

179

这里是一个更完整的示例,来自这里:

optimizer.zero_grad()        
loss, hidden = model(data, hidden, targets)
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()

3
为什么这个更完整?我看到了更多的投票,但不太明白为什么这更好。你能解释一下吗? - Gulzar
19
这只是一个常见模式,可以在loss.backward()和optimizer.step()之间插入torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)来对梯度进行裁剪。 - Rahul
14
args.clip是什么? - Farhang Amaji
1
在进行前向传递之前或之后调用opt.zero_grad()是否有影响?我猜可能越早清零,内存释放就越快? - Charlie Parker
6
@FarhangAmaji max_norm(剪切阈值)的值来自args(可能是来自argparse模块)。 - vdi
对于 "args.clip",你可以使用 0.01;例如 torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)。 - russian_spy

108

clip_grad_norm_ 替代了被弃用的 clip_grad_norm,其采用了更加一致的语法,在进行原地修改时会在函数名后添加下划线(_)。该函数通过将传递给函数的所有参数连接起来来剪裁梯度的 总体 范数,如文档所述

范数是计算所有梯度的向量拼接后得到的。梯度是原地修改的。

从您的示例中可以看出,您需要使用 clip_grad_value_,它具有类似的语法,并且也可以原地修改梯度:

clip_grad_value_(model.parameters(), clip_value)

另一种选择是注册一个 反向钩子。它以当前梯度作为输入,并可能返回一个张量,该张量将替换先前的梯度,即修改先前的梯度。每次计算完梯度后,都会调用此钩子,即在注册钩子后不需要手动剪辑:

for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

19
这里值得一提的是,这两种方法并不相等。后一种使用注册钩子的方法明显是大多数人想要的。这两种方法之间的区别在于,后一种方法在反向传播期间裁剪梯度,并且第一种方法则在整个反向传播完成后才裁剪梯度。 - c0mr4t
6
为什么我们要在反向传播期间剪辑梯度而不是之后呢?试图理解为什么后者比前者更可取。 - NikSp
7
如果在反向传播期间进行剪裁,则被剪裁的梯度会向上游层级传播。否则,原始梯度会向上传播,这可能会使得上游层级的梯度饱和(如果在反向传播之后进行剪裁)。如果所有层级的梯度都饱和在阈值(剪裁)值上,这可能会阻止收敛。@NikSp - a_guest
1
你能详细说明如何确保后者进行L2范数剪裁吗?目前看起来它只是剪裁了单个元素的绝对值。此外,register_hook只在梯度上工作吗?因为我本来期望像param.grad这样的东西。谢谢。 - sachinruk
虽然注册一个钩子是一个不错的选择,但是答案中的钩子似乎没有应用规范剪辑。它只是剪辑了各个元素,而不是梯度元素的规范。 - Shiania White
@a_guest 当启用AMP时,如何使用hook剪裁梯度?如果我简单地按照hkchengrex的建议“unscale”梯度,会起作用吗? - undefined

15

阅读论坛讨论后得出如下结论:

clipping_value = 1 # arbitrary value of your choosing
torch.nn.utils.clip_grad_norm(model.parameters(), clipping_value)

我相信这不仅仅只是这段代码片段所涵盖的。


13

如果您正在使用自动混合精度(AMP),在剪辑之前需要做更多工作,因为AMP会缩放梯度:

optimizer.zero_grad()
loss = model(data, targets)
scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)

# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
scaler.step(optimizer)

# Updates the scale for next iteration.
scaler.update()

参考文献:https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping

该链接是关于PyTorch的自动混合精度(AMP)中梯度裁剪的示例说明。在训练深度神经网络时,梯度裁剪是一种常用的技术,旨在防止梯度爆炸现象的发生。本链接提供了使用PyTorch执行梯度裁剪的代码示例。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接