PyTorch中的.flatten()和.view(-1)有什么区别?

29

.flatten().view(-1)都可以在PyTorch中对张量进行扁平化操作。它们之间的区别是什么?

  1. .flatten()是否会复制张量数据?
  2. .view(-1)是否更快?
  3. 是否存在某种情况.flatten()无法使用?

4
我认为默认参数情况下,.flatten().flatten()是相同的,但是.flatten()可以让你传递start_dimend_dim,以实现更复杂的行为。例如,torch.ones(10, 4, 5, 6).flatten(start_dim=1, end_dim=2)将返回形状为(10, 20, 6)的张量。 - adeelh
3个回答

19
除了@adeelh的评论之外,还有另一个区别:torch.flatten()会导致.reshape(),而.reshape().view()之间的差异是:
  • [...] torch.reshape可能返回原始张量的副本或视图。您不能指望它返回视图或副本。

  • 另一个区别是,reshape()可以操作连续和非连续张量,而view()只能操作连续张量。关于连续的含义,请参见此处

背景:

  • 社区要求有一个flatten函数,经过Issue #7743后,该功能在PR #8578中实现。

  • 您可以在此处看到flatten的实现,其中可以看到在return行中调用了.reshape()


13
flatten只是1一种常见的使用方式view的简便别名。
还有其他几种用法: | 功能 | 等效的view逻辑 | | --- | --- | | flatten() | view(-1) | | flatten(start, end) | view(*t.shape[:start], -1, *t.shape[end+1:]) | | squeeze() | view(*[s for s in t.shape if s != 1]) | | unsqueeze(i) | view(*t.shape[:i-1], 1, *t.shape[i:]) |
请注意,flatten允许您压缩特定连续维度的子集,使用start_dimend_dim参数。
实际上,在底层,reshape 是表面等效的。

2
首先,.view() 只适用于连续数据,而.flatten() 适用于 连续非连续数据。像transpose这样生成非连续数据的函数,可以通过.flatten()进行操作,但不能使用.view()

关于数据的复制,当它们处理连续数据时,.view().flatten()都不会复制数据。然而,在非连续数据的情况下,.flatten()首先将数据复制到连续内存中,然后更改维度。新张量上的任何更改都不会影响原始张量。

 ten=torch.zeros(2,3)
 ten_view=ten.view(-1)
 ten_view[0]=123
 ten 

>>tensor([[123.,   0.,   0.],
           [  0.,   0.,   0.]])

 ten=torch.zeros(2,3)
 ten_flat=ten.flatten()
 ten_flat[0]=123
 ten

>>tensor([[123.,   0.,   0.],
        [  0.,   0.,   0.]])

在上述代码中,张量ten具有连续的内存分配。对ten_viewten_flat的任何更改都会反映在张量ten上。
ten=torch.zeros(2,3).transpose(0,1)
ten_flat=ten.flatten()
ten_flat[0]=123
ten

>>tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

在这种情况下,非连续的转置张量 ten 被用于 flatten()。对 ten_flat 进行的任何更改都不会反映在 ten 上。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接