PyTorch首选复制张量的方式

171

在PyTorch中似乎有多种方法可以创建张量的副本,包括

y = tensor.new_tensor(x) #a

y = x.clone().detach() #b

y = torch.empty_like(x).copy_(x) #c

y = torch.tensor(x) #d

根据我执行ad时收到的UserWarning,明确优先使用b。为什么要优先使用b?性能吗?我认为它不太容易阅读。

有关使用c的任何理由,是否支持或反对?


7
“b”的一个优点是明确了“y”不再是计算图的一部分,即不需要梯度,而“c”与前三个不同之处在于“y”仍需要梯度。 - Shihab Shahriar Khan
1
torch.empty_like(x).copy_(x).detach()怎么样 - 这跟 a/b/d 是一样的吗? 我知道这不是一个明智的做法,我只是想了解自动求导的工作原理。我对clone()的文档感到困惑,其中写道:“与 copy_() 不同,此函数记录在计算图中”,这使我认为 copy_() 不需要 grad。 - dkv
3
文档中有一份相当明确的注释:当数据是张量 x 时,new_tensor() 从传递给它的任何东西中读取“数据”,并构造一个叶子变量。 因此,tensor.new_tensor(x) 等价于 x.clone().detach(),而 tensor.new_tensor(x, requires_grad=True) 则等价于 x.clone().detach().requires_grad_(True)。 推荐使用使用 clone() 和 detach() 的等价方式。 - cleros
6
Pytorch '1.1.0' 推荐使用 #b 并在 #d 中显示警告。 - macharya
@ManojAcharya,也许可以考虑将您的评论添加为答案。 - Shagun Sodhani
.clone()本身怎么样? - Charlie Parker
4个回答

175

简述

使用.clone().detach() (或更好的.detach().clone())

如果你首先分离张量,然后再复制它,计算路径不会被复制,反之亦然。因此,.detach().clone() 稍微更有效率。-- PyTorch 论坛

因为这种方式稍微快一些,并且在执行过程中更加明确。


使用perflot,我绘制了复制pytorch张量的各种方法的时间。
y = tensor.new_tensor(x) # method a

y = x.clone().detach() # method b

y = torch.empty_like(x).copy_(x) # method c

y = torch.tensor(x) # method d

y = x.detach().clone() # method e

该图的x轴是创建张量所需要的维度,y轴代表时间。该图是线性比例尺。您可以清楚地看到,tensor()new_tensor()方法相比其他三种方法需要更长的时间。

enter image description here

注意: 在多次运行中,我注意到在b、c、e中,任何一种方法都可能有最低的时间。a和d也是如此。但是,方法b、c、e的时间始终比a和d低。

import torch
import perfplot

perfplot.show(
    setup=lambda n: torch.randn(n),
    kernels=[
        lambda a: a.new_tensor(a),
        lambda a: a.clone().detach(),
        lambda a: torch.empty_like(a).copy_(a),
        lambda a: torch.tensor(a),
        lambda a: a.detach().clone(),
    ],
    labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
    n_range=[2 ** k for k in range(15)],
    xlabel="len(a)",
    logx=False,
    logy=False,
    title='Timing comparison for copying a pytorch tensor',
)

2
愚蠢的问题,但为什么我们需要clone()?否则两个张量指向相同的原始数据吗? - gebbissimo
2
啊,是的,请查看https://discuss.pytorch.org/t/clone-and-detach-in-v0-4-0/16861/2?u=gebbissimo - gebbissimo

24
根据Pytorch文档,#a和#b是等效的。文档还指出:

建议使用clone()和detach()方法进行替代。

因此,如果您想复制张量并从计算图中分离,应该使用:
y = x.clone().detach()

因为这是最干净和最易读的方式。所有其他版本都有一些隐藏的逻辑,而且计算图和梯度传播发生了什么也不是100%清楚。

关于 #c:它似乎对实际完成的工作有点复杂,并且可能会引入一些开销,但我对此不确定。

编辑:由于在评论中询问为什么不只使用 .clone()

来自pytorch docs

与copy_()不同,此函数记录在计算图中。传播到克隆张量的梯度将传播到原始张量。

因此,虽然 .clone()返回数据的副本,但它保留计算图并在其中记录克隆操作。如上所述,这将导致传播到克隆张量的梯度也传播到原始张量。这种行为可能会导致错误,并且不明显。由于这些可能的副作用,如果明确希望此行为,则应仅通过 .clone()克隆张量。为避免这些副作用,添加了 .detach()以从克隆的张量中断计算图。

由于通常对于复制操作,人们希望获得一个干净的副本,它不会导致意外的副作用,因此将张量复制的首选方式是 .clone().detach()


2
为什么需要使用 detach() 函数? - a06e
5
与copy_()不同,这个函数会被记录在计算图中。传播到克隆张量的梯度将传播到原始张量。因此,如果你想要真正复制张量,你需要将它分离(detach),否则你可能会得到一些不想要的梯度更新,而你不知道这些更新来自哪里。 - Nopileos
.clone()本身怎么样? - Charlie Parker
1
我添加了一些文字来解释为什么不能自己克隆。希望这回答了你的问题。 - Nopileos

6

检查张量是否被复制的一个例子:

import torch
def samestorage(x,y):
    if x.storage().data_ptr()==y.storage().data_ptr():
        print("same storage")
    else:
        print("different storage")
a = torch.ones((1,2), requires_grad=True)
print(a)
b = a
c = a.data
d = a.detach()
e = a.data.clone()
f = a.clone()
g = a.detach().clone()
i = torch.empty_like(a).copy_(a)
j = torch.tensor(a) # UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


print("a:",end='');samestorage(a,a)
print("b:",end='');samestorage(a,b)
print("c:",end='');samestorage(a,c)
print("d:",end='');samestorage(a,d)
print("e:",end='');samestorage(a,e)
print("f:",end='');samestorage(a,f)
print("g:",end='');samestorage(a,g)
print("i:",end='');samestorage(a,i)

输出:

tensor([[1., 1.]], requires_grad=True)
a:same storage
b:same storage
c:same storage
d:same storage
e:different storage
f:different storage
g:different storage
i:different storage
j:different storage

如果出现不同的存储方式,张量将被复制。

PyTorch拥有近100种不同的构造函数,因此您还可以添加更多。

如果我需要复制张量,我只需使用copy()命令,这也会复制与自动微分相关的信息,所以如果我需要删除与AD相关的信息,我会使用:

y = x.clone().detach()

使用"if x.untyped_storage().data_ptr()==y.untyped_storage().data_ptr():"代替"if x.storage().data_ptr()==y.storage().data_ptr():"以避免使用已弃用的TypedStorage。 - Tom J

4

Pytorch '1.1.0' 推荐使用 #b,同时会对 #d 发出警告提示。


.clone()本身怎么样? - Charlie Parker
克隆本身也将保留与原始图形相关联的变量。 - macharya

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接