如何在Pytorch中可视化神经网络?

108
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从pytorch模型中可视化resnet。我该如何做?我尝试使用torchviz,但出现了错误:

'ResNet' object has no attribute 'grad_fn'

你使用的是哪个版本的PyTorch? - dennlinger
最新的主分支代码 - raaj
PyTorch对TensorBoard的支持怎么样? - Charlie Parker
6个回答

112

这里有三个使用不同工具的图形可视化。

为了生成示例可视化,我将使用一个简单的RNN来执行情感分析,该模型是从一个在线教程中获取的:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

如果你使用print()函数输出模型,这里是结果。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

以下是三种不同可视化工具的结果。

对于所有这些工具,您需要有可以通过模型的 forward() 方法的虚拟输入。获取此输入的简单方法是从您的数据加载器中检索一个批次,就像这样:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Torchviz

https://github.com/szagoruyko/pytorchviz

我认为这个工具是使用反向传播生成其图形的,因此所有框都使用 PyTorch 组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

这个工具生成以下输出文件:

torchviz output

这是唯一一个清楚地提到我模型中三个层(embeddingrnnfc)的输出。操作符名称来自于反向传播,因此有些很难理解。

HiddenLayer

https://github.com/waleedka/hiddenlayer

我认为这个工具使用了前向传播。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这里是输出结果。我喜欢蓝色的色调。

hiddenlayer output

我发现输出结果太过详细,会掩盖我的架构细节。例如,为什么会这么多次提到unsqueeze

Netron

https://github.com/lutzroeder/netron

Netron是一款适用于Mac、Windows和Linux平台的桌面应用程序。它需要先将模型导出到ONNX格式,然后读取ONNX文件并呈现出来。然后可以选择将模型导出为图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

以下是该应用程序中模型的外观。我认为这个工具相当不错,您可以缩放和平移,还可以深入图层和操作器。我发现唯一的负面是它只支持纵向布局。

Netron screenshot


3
Netron也支持水平布局(请参见菜单)。 - paulgavrikov
我无法让GATConv神经网络正常工作? 我在get_var_name中遇到了一个异常:'NoneType'对象没有属性'size'。 - Thomas Gak-Deluen
Netron确实有一个选项“显示水平”。对我非常有效。 - David Jung
4
抱歉,batch.text是什么? - Prakhar Sharma
我认为这是批量数据,只需放入一个虚拟批量数据即可。 - undefined

41

make_dot 函数期望的是一个变量(即带有 grad_fn 的张量),而不是模型本身。
请尝试:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

2
如何将图像保存为文件? - Charlie Parker
2
这展示了反向传播的过程,但我可以知道如何查看正向传播吗? - Luk Aron
@LukAron,你的前向传递和后向传递有什么不同之处?后向传递是由前向传递(以及梯度链规则)定义的。 - Shai
1
我认为 requires_grad=False 是不正确的。Torchviz 需要将其设置为 True 才能够进行反向传播和跟踪网络图结构。 - Thariq Nugrohotomo

22
这可能是一个迟来的答案。但是,特别是在开发了__torch_function__后,可以获得更好的可视化效果。您可以在这里尝试我的项目torchview
对于您的resnet50示例,请查看colab笔记本,在 here 我演示了resnet18模型的可视化。Resnet18的图像由以下代码生成。
import torchvision
from torchview import draw_graph

model_graph = draw_graph(resnet18(), input_size=(1,3,224,224), expand_nested=True)
model_graph.visual_graph

Resnet by Torchview

它还接受各种输出/输入类型(例如列表、字典)


1
虽然这个链接可能回答了问题,但最好在此处包含答案的基本部分并提供参考链接。如果链接页面更改,仅有链接的答案可能会失效。-【来自审查】 - isaactfa
1
@isaactfa 我试图添加显示结果的图片,但我没有足够高的徽章来添加图片。您有没有一种解决方案可以显示我提供链接的图片? - Mert Kurttutan
代码可以运行,但没有出现图形。 - ojunk
无法访问Colab笔记本。 - ojunk
@ojunk 我刚刚更新了笔记本的链接,现在应该可以使用了。我还检查了代码,对我来说运行良好。 - Mert Kurttutan
1
它在Google Colab上运行良好。如果您想在VSCode上使其工作并将其保存为PNG或SVG,请使用model_graph.resize_graph(scale=5.0) # scale as per the view model_graph.visual_graph.render(format='svg') - Kevin Patel

16

3
如何将图表保存为图片? - Charlie Parker
4
from graphviz import Source; 表示导入Graphviz库中的Source功能。model_arch = make_dot(...); 是一个函数调用语句,其中 make_dot() 函数的输入参数没有给出。需要在上下文中找到该函数并获取其详细说明。Source(model_arch).render(filepath);model_arch 渲染为图像文件,并保存到指定路径 filepath - Nagabhushan S N
1
这展示了我们进行反向传播时会发生什么。但我可以知道如何查看前向传播吗? - Luk Aron
@CharlieParker https://github.com/szagoruyko/pytorchviz/issues/24 - Marine Galantin

13

如果您想保存图像,可以使用torchviz来完成:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

您获得的图像截图:

在此输入图像描述

来源:http://www.bnikolic.co.uk/blog/pytorch-detach.html


这展示了反向传播的过程,但我可以知道如何查看正向传播吗? - Luk Aron
@LukAron 那基本上就是前向传播...它只是那些操作的反向版本。 - Charlie Parker
+1 这是唯一一个对我的torchviz实际有效的答案。重要的部分是 requires_grad=True - Thariq Nugrohotomo

4
如果我可以不要脸地自荐一下,我写了一个包叫做TorchLens,它可以用一行代码将PyTorch模型图可视化(它应该适用于任何任意的PyTorch模型,但如果对你的模型无效,请告诉我)。

它的效果很好!但是我尝试了torchvision.models.detection.fcos_resnet50_fpn(),生成图像的时间太长了(在CPU上)。有没有办法更快地生成呢? - undefined
感谢您的反馈,这是我改进包装的方式:] 如果您在过程中不保存任何激活状态,可视化会稍微快一些,但对于非常复杂的模型来说,总是需要一些时间(但我一直在寻找加快速度的方法)。与此同时,如果您想要特定模型的可视化效果,可以在以下链接中找到,包括训练和评估模式:https://drive.google.com/drive/u/0/folders/1uq3VNqpvaacWMKiNb1_oLR2uzSmJr9ms - undefined

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接