How can I get the predecessor nodes of each layer in PyTorch?


I can get a summary of a model in PyTorch, much like in Keras:

import torch
from torchvision import models
from torchsummary import summary  # assuming the torchsummary package

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet = models.resnet18().to(device)

summary(resnet, (3, 224, 224))

The output looks like this:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
        AvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================

In Keras, however, the summary also shows the predecessor nodes of each layer:

Model Summary:
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
====================================================================================================

How can I get the predecessor nodes of each layer in PyTorch? I looked at the model's OrderedDict, but it contains no predecessor information.


Some neural networks are DAGs. In ResNet, for example, by "predecessor nodes" I mean the layers that feed into a given layer. - lee YingGang
Is that the "Connected to" column in Keras? - ndrwnaguib
I'm not sure, but doesn't the order in which the layers are listed convey the same information? - ndrwnaguib
PyTorch only gives information about each layer, not about each layer's predecessors, whereas Keras does. Take InceptionV4: its overall structure is not a chain but a directed acyclic graph. Maybe that's because PyTorch uses dynamic graphs? - lee YingGang
No, PyTorch does not expose the topology; it just lists the layers in order. - lee YingGang
1 Answer


You're right - PyTorch uses dynamic computation graphs, so it has no built-in notion of a layer's predecessors or successors. The Inception3 model, for instance, is defined by declaring a bunch of submodules and then running them in a long, hand-coded forward method that simply uses them in some order.

This allows arbitrary control flow, in which case you would have a hard time saying which layer follows a given one - it can depend on the input data.
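To make this concrete, here is a minimal sketch (a made-up module, not anything from torchvision) in which the layer that follows self.conv is only decided at run time:

import torch
import torch.nn as nn

class DataDependent(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.branch_a = nn.Conv2d(8, 8, 3, padding=1)
        self.branch_b = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv(x)
        # Which module runs after self.conv depends on the data, so no
        # static summary can assign it a single successor.
        if x.mean() > 0:
            return self.branch_a(x)
        return self.branch_b(x)

net = DataDependent()
out = net(torch.randn(1, 3, 8, 8))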

For some special cases, however, it is possible. The VGG models, for example, are built with nn.Sequential, which applies a list of modules one after another to its input. If you have a model like this:

import torch.nn as nn
model = nn.Sequential(nn.Linear(30, 40), nn.Linear(40, 20), nn.Linear(20, 30))

then you know that the predecessor of the second Linear layer (model[1]) is model[0], and its successor is model[2].
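For a plain nn.Sequential you can recover these neighbors mechanically. A minimal sketch (the helper print_neighbors is my own, not a PyTorch API):

import torch.nn as nn

def print_neighbors(seq):
    # Walk the children of an nn.Sequential in order and report each
    # layer's predecessor and successor by position.
    layers = list(seq)
    for i, layer in enumerate(layers):
        pred = layers[i - 1] if i > 0 else None
        succ = layers[i + 1] if i < len(layers) - 1 else None
        print(f'{i}: {layer} | predecessor: {pred} | successor: {succ}')

model = nn.Sequential(nn.Linear(30, 40), nn.Linear(40, 20), nn.Linear(20, 30))
print_neighbors(model)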

To my layman's eye, the Inception model looks like it could mostly be implemented with nn.Sequential containers, which would give you the property you want. That said, it isn't done that way (at least not in the torchvision model zoo), so you cannot retrieve this information other than by hand.
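To sketch that idea on a toy model (hypothetical code, not the actual torchvision source): a hand-coded forward hides the layer order inside the method body, while an nn.Sequential makes it indexable.

import torch.nn as nn

# Hand-coded forward: the order of the layers lives only inside forward(),
# so nothing outside the method can see that fc1 precedes fc2.
class Hidden(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(30, 40)
        self.fc2 = nn.Linear(40, 20)

    def forward(self, x):
        return self.fc2(self.fc1(x))

# The same computation as an nn.Sequential: the order is now inspectable.
exposed = nn.Sequential(nn.Linear(30, 40), nn.Linear(40, 20))
print(exposed[0], '->', exposed[1])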

