PyTorch函数中的下划线后缀是什么意思?

19
在PyTorch中,一个张量的许多方法存在两个版本——一个带有下划线后缀,一个没有。如果我尝试它们,它们似乎做相同的事情。
In [1]: import torch

In [2]: a = torch.tensor([2, 4, 6])

In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])

In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])

torch.add和torch.add_之间的区别是什么?

  • torch.subtorch.sub_之间的区别是什么?
  • ...以此类推?
3个回答

12

你已经回答了你自己的问题,即下划线在PyTorch中表示原地操作。但是我想简要指出为什么原位操作可能会有问题:

  • 首先,在大多数情况下,PyTorch网站建议不要使用原地操作。除非在内存压力较大的情况下,在大多数情况下不使用原地操作更有效率
    https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

  • 其次,在使用原地操作时可能存在计算梯度的问题:

    每个张量都保留一个版本计数器(version counter),它在任何操作中被标记为脏时都会增加。当 Function 保存反向传播所需的任何张量时,它们所包含的 Tensor 的版本计数器也会被保存。一旦你访问 self.saved_tensors,就会检查它,如果它大于保存的值,则会引发错误。这确保了如果你正在使用原位函数并没有看到任何错误,则可以确信计算出的梯度是正确的。 同上面的来源。

这里是一个来自你回答中稍微修改的例子:

首先是原地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

这导致了以下错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
      2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
      3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
      5 c = torch.sum(b)
      6 c.backward()

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

其次是非原地版本:


import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

这可以完美地工作 - 输出:

<SumBackward0 object at 0x7f06b27a1da0>
因此,我的建议是小心使用PyTorch中的原位操作。

6
根据文档,以下划线结尾的方法会原地更改张量。这意味着通过执行操作不会分配新内存,通常提高性能, 但在PyTorch中可能导致问题和更差的性能
In [2]: a = torch.tensor([2, 4, 6])

tensor.add():

In [3]: b = a.add(10)

In [4]: a is b
Out[4]: False # b is a new tensor, new memory was allocated

tensor.add_():

In [3]: b = a.add_(10)

In [4]: a is b
Out[4]: True # Same object, no new memory was allocated

请注意,运算符++=也是两种不同的实现+使用.add()创建一个新的张量,而+=使用.add_()修改张量。

In [2]: a = torch.tensor([2, 4, 6])

In [3]: id(a)
Out[3]: 140250660654104

In [4]: a += 10

In [5]: id(a)
Out[5]: 140250660654104 # Still the same object, no memory allocation was required

In [6]: a = a + 10

In [7]: id(a)
Out[7]: 140250649668272 # New object was created

3
你有没有关于这个陈述的PyTorch代码:“就地[...]操作[...]可以显著提高性能[...]因此应优先使用就地方法”?据我所知,对于PyTorch来说恰恰相反 - 在大多数情况下强烈不建议使用PyTorch中的就地操作。https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd - MBT

0

这个回答对现有的回答没有任何补充。 - Sören

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接