.flatten()
和.view(-1)
都可以在PyTorch中对张量进行扁平化操作。它们之间的区别是什么?
.flatten()
是否会复制张量数据?.view(-1)
是否更快?- 是否存在某种情况
.flatten()
无法使用?
torch.flatten()
会导致.reshape()
,而.reshape()
和.view()
之间的差异是:
[...]
torch.reshape
可能返回原始张量的副本或视图。您不能指望它返回视图或副本。另一个区别是,reshape()可以操作连续和非连续张量,而view()只能操作连续张量。关于连续的含义,请参见此处
背景:
社区要求有一个flatten
函数,经过Issue #7743后,该功能在PR #8578中实现。
您可以在此处看到flatten的实现,其中可以看到在return
行中调用了.reshape()
。
flatten
只是1一种常见的使用方式的view
的简便别名。view
逻辑 |
| --- | --- |
| flatten()
| view(-1)
|
| flatten(start, end)
| view(*t.shape[:start], -1, *t.shape[end+1:])
|
| squeeze()
| view(*[s for s in t.shape if s != 1])
|
| unsqueeze(i)
| view(*t.shape[:i-1], 1, *t.shape[i:])
|flatten
允许您压缩特定连续维度的子集,使用start_dim
和end_dim
参数。
reshape
是表面等效的。.view()
只适用于连续数据,而.flatten()
适用于 连续 和 非连续数据。像transpose这样生成非连续数据的函数,可以通过.flatten()
进行操作,但不能使用.view()
。.view()
和.flatten()
都不会复制数据。然而,在非连续数据的情况下,.flatten()
首先将数据复制到连续内存中,然后更改维度。新张量上的任何更改都不会影响原始张量。 ten=torch.zeros(2,3)
ten_view=ten.view(-1)
ten_view[0]=123
ten
>>tensor([[123., 0., 0.],
[ 0., 0., 0.]])
ten=torch.zeros(2,3)
ten_flat=ten.flatten()
ten_flat[0]=123
ten
>>tensor([[123., 0., 0.],
[ 0., 0., 0.]])
ten=torch.zeros(2,3).transpose(0,1)
ten_flat=ten.flatten()
ten_flat[0]=123
ten
>>tensor([[0., 0.],
[0., 0.],
[0., 0.]])
.flatten()
和.flatten()
是相同的,但是.flatten()
可以让你传递start_dim
和end_dim
,以实现更复杂的行为。例如,torch.ones(10, 4, 5, 6).flatten(start_dim=1, end_dim=2)
将返回形状为(10, 20, 6)
的张量。 - adeelh