在numpy中,我们使用ndarray.reshape()
来重新整形一个数组。
我注意到在pytorch中,人们使用torch.view(...)
来完成相同的功能,但同时也存在一个torch.reshape(...)
。
因此,我想知道它们之间的区别以及何时应该使用它们中的任意一个?
torch.view
已经存在很长一段时间了。它将返回一个具有新形状的张量。返回的张量将与原始张量共享底层数据。请参阅此处的文档。
另一方面,似乎torch.reshape
在版本0.4中最近被引入。根据文档,此方法将
返回一个与输入具有相同数据和元素数量但具有指定形状的张量。如果可能,返回的张量将是输入的视图。否则,它将是一个副本。具有连续输入和兼容跨度的输入可以重新整形而无需复制,但您不应该依赖于复制或查看行为。
这意味着torch.reshape
可能返回原始张量的副本或视图。您不能保证返回视图或副本。根据开发人员所说:
另一个区别是,如果需要副本,请使用clone(),如果需要相同的存储,请使用view()。reshape()的语义是它可能共享存储,也可能不共享,并且事先不知道。
reshape()
可以操作连续和非连续张量,而 view()
只能操作连续张量。此外,有关 contiguous
的含义,请参见 此处。torch.view
和torch.reshape
都用于重塑张量,但它们之间存在以下差异。
torch.view
仅创建原始张量的视图。新张量将始终与原始张量共享其数据。这意味着如果更改原始张量,则重塑后的张量也会更改,反之亦然。>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.view
对两个张量的形状施加一些连续性约束[docs]。通常情况下,这不是一个问题,但有时即使两个张量的形状兼容,torch.view
也会抛出错误。以下是一个著名的反例。>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
torch.reshape
没有强制要求连续性约束,但也不能保证数据共享。新张量可能是原始张量的视图,也可能是全新的张量。>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])
TL;DR:
如果您只想重新塑造张量,请使用torch.reshape
。如果您还关心内存使用情况,并希望确保两个张量共享相同的数据,请使用torch.view
。
x
和y
都是连续的)。也许可以澄清一下?或许对于何时进行reshape并复制的注释会有帮助? - RMurphyx
和y
是连续的,但我们关心的是z
和z.t()
。z
是连续的,所以y
和z
共享相同的数据。z
和z.t()
共享相同的数据,但z.t()
不是连续的,因此z.t()
和x
不共享相同的数据。因此,x
和y
不共享相同的数据。 - Willa = torch.arange(8).reshape(2, 4)
内存分配如下(它是C连续的,即行存储在一起):
stride()函数返回在每个维度中前进到下一个元素所需的字节数:a.stride()
(4, 1)
a.view(4,2)
底层数据分配没有改变,张量仍然是C连续的。
a.view(4, 2).stride()
(2, 1)
a.t().is_contiguous()
False
虽然它不是连续的,但步幅信息足以在张量上进行迭代。a.t().stride()
(1, 4)
a.t().view(2, 4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
a.t().view(2, 2, 2)
a.t().view(2, 2, 2).stride()
(2, 1, 4)
根据文档:
要查看一个张量,新的视图大小必须与其原始大小和步长兼容,即每个新的视图维度必须是原始维度的子空间,或者仅跨越满足以下连续性条件的原始维度 d、d+1、…、d+k:
stride[i]=stride[i+1]×size[i+1]
这是因为在应用 view(2, 2, 2) 后的前两个维度是转置的第一个维度的子空间。
有关连续性的更多信息,请参阅我在此主题中的回答。
连续
的含义,即指在一个行中是否索引所有下一个数字是连续的。顺便说一下,在b.t().is_contiguous()
中有一个小错别字,可能应该是a.t().is_contiguous()
,仍然感谢! - Wade WangTensor.reshape()
更加健壮。它适用于任何张量,而 Tensor.view()
仅适用于具有 t.is_contiguous()==True
的张量 t
。
关于非连续和连续的解释是另外一个故事,但如果您调用 t.contiguous()
使张量 t
连续,就可以在不出错的情况下调用 view()
。
我认为这里的答案在技术上是正确的,但reshape
存在的另一个原因。通常认为pytorch
比其他框架更方便,因为它更接近python
和numpy
。有趣的是,这个问题涉及到numpy
。
让我们来看看pytorch
中的size
和shape
。 size
是一个函数,所以您可以像x.size()
这样调用它。pytorch
中的shape
不是一个函数。在numpy
中,您有shape
,它不是一个函数 - 您使用x.shape
。因此,在pytorch
中获取两者都很方便。如果您来自numpy
,那么使用相同的函数会很好。