使用torch.no_grad()的目的是什么?

8

考虑使用PyTorch实现线性回归的以下代码:

X是训练集的输入,Y是输出,w是需要优化的参数
import torch

X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

def forward(x):
    return w * x

def loss(y, y_pred):
    return ((y_pred - y)**2).mean()

print(f'Prediction before training: f(5) = {forward(5).item():.3f}')

learning_rate = 0.01
n_iters = 100

for epoch in range(n_iters):
    # predict = forward pass
    y_pred = forward(X)

    # loss
    l = loss(Y, y_pred)

    # calculate gradients = backward pass
    l.backward()

    # update weights
    #w.data = w.data - learning_rate * w.grad
    with torch.no_grad():
        w -= learning_rate * w.grad
    
    # zero the gradients after updating
    w.grad.zero_()

    if epoch % 10 == 0:
        print(f'epoch {epoch+1}: w = {w.item():.3f}, loss = {l.item():.8f}')

`'with'` 块是什么?w 的 requires_grad 参数已经设置为 True。那么它为什么被放在 `with torch.no_grad()` 块下面呢?
2个回答

7
< p > requires_grad 参数告诉 PyTorch 我们想要为这些值计算梯度。然而, with torch.no_grad() 告诉 PyTorch 不要计算梯度,并且程序在此处明确使用它(与大多数神经网络一样),以便在更新权重时不更新梯度,因为那会影响反向传播。


3

在更新权重时,没有必要跟踪梯度;这就是为什么在任何优化器实现的step方法中都会找到一个修饰符(@torch.no_grad())。

"使用torch.no_grad"代码块表示执行这些行而不跟踪梯度。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接