torch.flatten() 和 nn.Flatten() 的区别

12
< p >torch.flatten()和torch.nn.Flatten()有什么区别?< /p >
2个回答

16

在PyTorch中,Flattening有三种形式:

  • 作为张量方法(oop风格)torch.Tensor.flatten 直接应用于一个张量:x.flatten()

  • 作为函数(functional形式)torch.flatten 应用方式为:torch.flatten(x)

  • 作为模块(nn.Module层) nn.Flatten()。通常在模型定义中使用。

所有这三种形式都相同且共享相同的实现,唯一的区别是nn.Flatten默认将start_dim设置为1以避免展平第一个轴(通常是批量轴)。而其他两种则从axis = 0axis = -1展平整个张量,如果没有给出参数。


仍然是真的吗?在C代码中,flatten使用了reshape吗? - prosti
1
找到了这个:torch/csrc/utils/tensor_flatten.h。看起来它使用了view,这是一种重塑操作! - Ivan

4
您可以将torch.flatten()的工作视为对张量进行简单的展平操作,没有任何附加条件。您提供一个张量,它会展平并返回它。就是这样。
相反,nn.Flatten()要复杂得多(即它是神经网络层)。作为面向对象编程,它继承自nn.Module,虽然在展平张量时forward()方法中内部使用纯张量.flatten()操作。您可以将其视为对torch.flatten()的一种语法糖。

重要区别:值得注意的区别是,torch.flatten()始终返回一个1D张量作为结果,前提是输入至少是1D或更大,而nn.Flatten()始终返回一个2D张量,前提是输入至少是2D或更大(对于1D张量作为输入,它会抛出IndexError异常)。


比较:

  • torch.flatten()是一个API,而nn.Flatten()是一个神经网络层。

  • torch.flatten()是一个Python函数,而nn.Flatten()是一个Python类。

  • 由于上述原因,nn.Flatten()带有许多方法和属性

  • torch.flatten()可以在实际应用中使用(例如,进行简单的张量操作),而nn.Flatten()预期将作为nn.Sequential()块中的一层来使用。

  • torch.flatten()没有关于计算图的信息,除非它被嵌入到其他具有tensor.requires_grad标志设置为True的图形感知块中,而nn.Flatten()始终被autograd跟踪。

  • torch.flatten()无法接受和处理(例如,线性/conv1D)层作为输入,而nn.Flatten()主要用于处理这些神经网络层。

  • torch.flatten()nn.Flatten()都返回输入张量的视图。因此,对结果的任何修改也会影响输入张量。(请参见下面的代码)


代码演示:

# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor

使用torch.flatten()进行扁平化:

In [113]: t1flat = torch.flatten(t1)

In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# modification to the flattened tensor    
In [115]: t1flat[-1] = -1

# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]: 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])

使用 nn.Flatten() 进行扁平化:

In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)

# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]: 
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])

# modification to the result
In [126]: t3flat[-1, -1] = -1

# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]: 
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])

小贴士: torch.flatten()nn.Flatten() 以及其 同类 nn.Unflatten() 的前身,因为它从一开始就存在。然后,出现了一个合法的 nn.Flatten() 的用例,因为这是几乎所有 ConvNets(在 softmax 或其他地方之前)的常见要求。因此,它在 PR #22245 中稍后被添加。

还有最近的 建议在 ResNets 中使用 nn.Flatten() 进行模型手术


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接