PyTorch中reshape和view有什么区别?

230

在numpy中,我们使用ndarray.reshape()来重新整形一个数组。

我注意到在pytorch中,人们使用torch.view(...)来完成相同的功能,但同时也存在一个torch.reshape(...)

因此,我想知道它们之间的区别以及何时应该使用它们中的任意一个?

5个回答

238

torch.view 已经存在很长一段时间了。它将返回一个具有新形状的张量。返回的张量将与原始张量共享底层数据。请参阅此处的文档

另一方面,似乎torch.reshape在版本0.4中最近被引入。根据文档,此方法将

返回一个与输入具有相同数据和元素数量但具有指定形状的张量。如果可能,返回的张量将是输入的视图。否则,它将是一个副本。具有连续输入和兼容跨度的输入可以重新整形而无需复制,但您不应该依赖于复制或查看行为。

这意味着torch.reshape 可能返回原始张量的副本或视图。您不能保证返回视图或副本。根据开发人员所说:

如果需要副本,请使用clone(),如果需要相同的存储,请使用view()。reshape()的语义是它可能共享存储,也可能不共享,并且事先不知道。

另一个区别是,reshape() 可以操作连续和非连续张量,而 view() 只能操作连续张量。此外,有关 contiguous 的含义,请参见 此处

70
强调torch.view只能用于连续的张量,而torch.reshape可以同时用于非连续张量和连续张量可能会有帮助。 - p13rr0m
8
“contiguous”在这里是指存储在连续内存中的张量,还是指其他东西? - gokul_uf
5
@gokul_uf,是的,你可以查看这里写的答案:https://dev59.com/0FUM5IYBdhLWcg3wY_a2 - MBT
1
view()可以操作非连续的张量。请见我的答案中的示例。 - Pierre
@pierrom 不,我认为MBT的评论可能有点令人困惑。这里的表达“连续”的意思并不是指数据是否存储在连续的内存块中。即使一个PyTorch张量不是“连续的”,元素也是按照连续的内存块排列的。这里的“连续”表达式与PyTorch查看张量时元素的顺序有关。 - starriet
显示剩余5条评论

96
虽然torch.viewtorch.reshape都用于重塑张量,但它们之间存在以下差异。
  1. 如名称所示,torch.view仅创建原始张量的视图。新张量将始终与原始张量共享其数据。这意味着如果更改原始张量,则重塑后的张量也会更改,反之亦然。
>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
        [1., 1., 1.]])

为了确保新张量始终与原始张量共享数据,torch.view对两个张量的形状施加一些连续性约束[docs]。通常情况下,这不是一个问题,但有时即使两个张量的形状兼容,torch.view也会抛出错误。以下是一个著名的反例。
>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
  1. torch.reshape没有强制要求连续性约束,但也不能保证数据共享。新张量可能是原始张量的视图,也可能是全新的张量。
>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])

TL;DR:
如果您只想重新塑造张量,请使用torch.reshape。如果您还关心内存使用情况,并希望确保两个张量共享相同的数据,请使用torch.view


3
也许只有我一个人这样想,但我认为连续性是决定reshape是否分享数据的关键因素。通过我的实验,似乎并非如此。(您上面的xy都是连续的)。也许可以澄清一下?或许对于何时进行reshape并复制的注释会有帮助? - RMurphy
1
xy是连续的,但我们关心的是zz.t()z是连续的,所以yz共享相同的数据。zz.t()共享相同的数据,但z.t()不是连续的,因此z.t()x不共享相同的数据。因此,xy不共享相同的数据。 - Will

35
view()会尝试改变张量的形状,同时保持底层数据的分配不变,因此数据将在两个张量之间共享。如果需要,reshape()将创建一个新的底层内存分配。
让我们创建一个张量:
a = torch.arange(8).reshape(2, 4)

initial 2D tensor

内存分配如下(它是C连续的,即行存储在一起):

initial 2D tensor's memory allocation

stride()函数返回在每个维度中前进到下一个元素所需的字节数:
a.stride()
(4, 1)

我们希望它的形状变为(4,2),我们可以使用视图:
a.view(4,2)

after view to switch the dimensions

底层数据分配没有改变,张量仍然是C连续的。

memory allocation after switch

a.view(4, 2).stride()
(2, 1)

让我们试试 a.t()。Transpose() 不会修改底层内存分配,因此 a.t() 不是连续的。
a.t().is_contiguous()
False

after transpose

memory allocation after transpose

虽然它不是连续的,但步幅信息足以在张量上进行迭代。
a.t().stride()
(1, 4)

view()不再起作用了。
a.t().view(2, 4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

以下是我们想要通过使用view(2, 4)获得的形状:

after transpose and reshape

内存分配会是什么样子?

memory allocation without reshape

步幅可能是(2,4),但是在到达末尾后,我们必须回到张量的开头。这样行不通。
在这种情况下,reshape()将创建一个具有不同内存分配的新张量,以使转置连续。

memory allocation with reshape or contiguous

请注意,我们可以使用view函数来分割转置的第一个维度。 与被接受的答案和其他答案所说的不同,view()函数可以操作非连续的张量!
a.t().view(2, 2, 2)

after transpose and view 2, 2, 2

memory allocation after transpose

a.t().view(2, 2, 2).stride()
(2, 1, 4)

根据文档

要查看一个张量,新的视图大小必须与其原始大小和步长兼容,即每个新的视图维度必须是原始维度的子空间,或者仅跨越满足以下连续性条件的原始维度 d、d+1、…、d+k:
stride[i]=stride[i+1]×size[i+1]

这是因为在应用 view(2, 2, 2) 后的前两个维度是转置的第一个维度的子空间。

有关连续性的更多信息,请参阅我在此主题中的回答


2
图示及其颜色深浅帮助我理解连续的含义,即指在一个行中是否索引所有下一个数字是连续的。顺便说一下,在b.t().is_contiguous()中有一个小错别字,可能应该是a.t().is_contiguous(),仍然感谢! - Wade Wang
感谢您的评论并指出了拼写错误!现在已经修复。 - Pierre
1
步幅应该是(2,4),对吗? - undefined
1
"但是当我们到达张量的末尾时,我们将不得不回到张量的开头" - 如果有模运算(对元素数量取模),那么这样的步长应该是正确的,对吗?我在想是否允许这样做会有意义。 - undefined
是的,你说得对。谢谢。 - undefined

18

Tensor.reshape() 更加健壮。它适用于任何张量,而 Tensor.view() 仅适用于具有 t.is_contiguous()==True 的张量 t

关于非连续和连续的解释是另外一个故事,但如果您调用 t.contiguous() 使张量 t 连续,就可以在不出错的情况下调用 view()


0

我认为这里的答案在技术上是正确的,但reshape存在的另一个原因。通常认为pytorch比其他框架更方便,因为它更接近pythonnumpy。有趣的是,这个问题涉及到numpy

让我们来看看pytorch中的sizeshapesize是一个函数,所以您可以像x.size()这样调用它。pytorch中的shape不是一个函数。在numpy中,您有shape,它不是一个函数 - 您使用x.shape。因此,在pytorch中获取两者都很方便。如果您来自numpy,那么使用相同的函数会很好。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接