PyTorch visualization example

conda install graphviz python-graphviz

pip install hiddenlayer
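Before building any graphs, it can help to confirm that the pieces installed above are actually visible to Python. The following is a minimal sketch using only the standard library; the helper name `check_environment` is made up for this example, and it only tests whether the packages and the Graphviz `dot` executable can be found, not whether their versions are compatible with each other.

```python
import importlib.util
import shutil


def check_environment():
    """Report whether each piece needed for rendering appears to be installed."""
    return {
        # Graphviz command-line binary (installed by `conda install graphviz`)
        "dot executable": shutil.which("dot") is not None,
        # Python bindings (installed by `conda install python-graphviz`)
        "graphviz": importlib.util.find_spec("graphviz") is not None,
        "hiddenlayer": importlib.util.find_spec("hiddenlayer") is not None,
        "torch": importlib.util.find_spec("torch") is not None,
        "torchvision": importlib.util.find_spec("torchvision") is not None,
    }


for name, found in check_environment().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

If `dot executable` is reported as missing while `graphviz` is found, only the Python bindings were installed and rendering will likely fail.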

import torch
import torchvision.models
import hiddenlayer as hl

VGG16 (torchvision.models.vgg16() has no BatchNorm layers; vgg16_bn() is the BatchNorm variant)

model = torchvision.models.vgg16()

Build HiddenLayer graph

Jupyter Notebook renders it automatically

hl.build_graph(model, torch.zeros([1, 3, 224, 224]))

AlexNet

model = torchvision.models.alexnet()

Build HiddenLayer graph

hl_graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224]))

Use a different color theme

hl_graph.theme = hl.graph.THEMES["blue"].copy() # Two options: basic and blue
hl_graph
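Outside a Jupyter Notebook the graph is not rendered automatically, so writing it to a file is the usual alternative. The sketch below assumes the graph object exposes `save(path, format=...)` and that the renderer appends the format as a file extension, which matches my recollection of HiddenLayer's `Graph.save` but should be checked against the installed version. The helper name `export_graph` is hypothetical.

```python
def export_graph(hl_graph, path, fmt="png"):
    """Save a graph through its save() method and return the expected file name.

    Assumption: hl_graph.save(path, format=fmt) exists and the output file
    ends up with the format as its extension.
    """
    hl_graph.save(path, format=fmt)
    return path if path.endswith("." + fmt) else f"{path}.{fmt}"


# Usage, with hl_graph built by hl.build_graph(...) as shown above:
# export_graph(hl_graph, "alexnet_graph")         # expected: alexnet_graph.png
# export_graph(hl_graph, "alexnet_graph", "pdf")  # expected: alexnet_graph.pdf
```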

ResNet101

model = torchvision.models.resnet101()

Rather than using the default transforms, build custom ones to group nodes of residual and bottleneck blocks.

transforms = [
    # Fold Conv, BN, RELU layers into one
    hl.transforms.Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
    # Fold Conv, BN layers together
    hl.transforms.Fold("Conv > BatchNorm", "ConvBn"),
    # Fold bottleneck blocks
    hl.transforms.Fold("""
        ((ConvBnRelu > ConvBnRelu > ConvBn) | ConvBn) > Add > Relu
        """, "BottleneckBlock", "Bottleneck Block"),
    # Fold residual blocks
    hl.transforms.Fold("""ConvBnRelu > ConvBnRelu > ConvBn > Add > Relu""",
                       "ResBlock", "Residual Block"),
    # Fold repeated blocks
    hl.transforms.FoldDuplicates(),
]

Display graph using the transforms above

hl.build_graph(model, torch.zeros([1, 3, 224, 224]), transforms=transforms)
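To see why the order of the transforms matters, here is a toy illustration of folding on a plain list of op names. It is not HiddenLayer's implementation: the real transforms work on a graph and support branches (`|`), while the hypothetical `fold` and `fold_duplicates` helpers below only handle a linear chain.

```python
def fold(ops, pattern, name):
    """Replace every consecutive run matching `pattern` with a single `name`."""
    out, i, n = [], 0, len(pattern)
    while i < len(ops):
        if ops[i:i + n] == pattern:
            out.append(name)
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


def fold_duplicates(ops):
    """Collapse runs of identical adjacent names into (name, count) pairs."""
    out = []
    for op in ops:
        if out and out[-1][0] == op:
            out[-1] = (op, out[-1][1] + 1)
        else:
            out.append((op, 1))
    return out


ops = ["Conv", "BatchNorm", "Relu",
       "Conv", "BatchNorm", "Relu",
       "Conv", "BatchNorm"]

# The longer pattern goes first; otherwise "Conv > BatchNorm" would consume
# the prefix of every "Conv > BatchNorm > Relu" run.
ops = fold(ops, ["Conv", "BatchNorm", "Relu"], "ConvBnRelu")
ops = fold(ops, ["Conv", "BatchNorm"], "ConvBn")
print(ops)                   # ['ConvBnRelu', 'ConvBnRelu', 'ConvBn']
print(fold_duplicates(ops))  # [('ConvBnRelu', 2), ('ConvBn', 1)]
```

The same reasoning applies to the list above: the block-level patterns refer to `ConvBnRelu` and `ConvBn`, so they can only match after the two layer-level folds have run.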

posted @ 2021-01-24 17:33  xinkevinzhang