PyTorch visualization example
conda install graphviz python-graphviz
pip install hiddenlayer
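HiddenLayer draws PyTorch graphs through Graphviz, which is why both the `graphviz` conda package (the `dot` renderer) and `python-graphviz` (the `graphviz` Python module) are installed above. The helper below is a minimal, hypothetical sketch for checking that setup before building any graphs. It only tests whether the `dot` executable and the relevant packages can be found, not whether their versions are compatible with each other.

```python
import importlib.util
import shutil
from typing import Dict


def check_visualization_deps() -> Dict[str, bool]:
    """Hypothetical helper: report whether each dependency can be located.

    Assumes the pieces needed are the Graphviz `dot` executable and the
    `graphviz`, `hiddenlayer` and `torch` Python packages.
    """
    return {
        # Graphviz command-line renderer (from `conda install graphviz`)
        "dot_executable": shutil.which("dot") is not None,
        # Python bindings for Graphviz (from `conda install python-graphviz`)
        "graphviz_package": importlib.util.find_spec("graphviz") is not None,
        # HiddenLayer itself (from `pip install hiddenlayer`)
        "hiddenlayer_package": importlib.util.find_spec("hiddenlayer") is not None,
        "torch_package": importlib.util.find_spec("torch") is not None,
    }


if __name__ == "__main__":
    for name, found in check_visualization_deps().items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
```

If `dot_executable` reports MISSING while `graphviz_package` is ok, the Python bindings are installed but have no renderer to call.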
import torch
import torchvision.models
import hiddenlayer as hl
# VGG16 with BatchNorm
model = torchvision.models.vgg16_bn()
# Build HiddenLayer graph
# Jupyter Notebook renders it automatically
hl.build_graph(model, torch.zeros([1, 3, 224, 224]))
# AlexNet
model = torchvision.models.alexnet()
# Build HiddenLayer graph
hl_graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224]))
# Use a different color theme
hl_graph.theme = hl.graph.THEMES["blue"].copy() # Two options: basic and blue
hl_graph  # as the last expression in a Jupyter cell, the graph is rendered inline
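The `.copy()` on `hl.graph.THEMES["blue"]` is presumably there because `THEMES` is a shared module-level dictionary: assigning it without a copy and then editing a key would change the theme for every graph that uses it. The sketch below shows that aliasing effect with plain dictionaries. `THEMES`, its keys and colour values, and `customized_theme` are all hypothetical stand-ins, not HiddenLayer's actual table.

```python
from typing import Dict

# Hypothetical stand-in for a shared theme table such as hl.graph.THEMES;
# the keys and colours are illustrative only.
THEMES: Dict[str, Dict[str, str]] = {
    "basic": {"fill_color": "#E8E8E8", "font_color": "#000000"},
    "blue": {"fill_color": "#BCD6FC", "font_color": "#FFFFFF"},
}


def customized_theme(base: str, **overrides: str) -> Dict[str, str]:
    """Return a modified copy of THEMES[base], leaving the shared table intact."""
    theme = THEMES[base].copy()  # shallow copy suffices: all values are strings
    theme.update(overrides)
    return theme


dark = customized_theme("blue", fill_color="#333333")
print(dark["fill_color"])             # #333333
print(THEMES["blue"]["fill_color"])   # #BCD6FC (shared entry unchanged)

# Without a copy, the edit leaks into the shared table:
alias = THEMES["basic"]
alias["fill_color"] = "#FF0000"
print(THEMES["basic"]["fill_color"])  # #FF0000
```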
# ResNet-101
model = torchvision.models.resnet101()
# Rather than using the default transforms, build custom ones to group
# nodes of residual and bottleneck blocks.
transforms = [
    # Fold Conv, BN, RELU layers into one
    hl.transforms.Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
    # Fold Conv, BN layers together
    hl.transforms.Fold("Conv > BatchNorm", "ConvBn"),
    # Fold bottleneck blocks
    hl.transforms.Fold("""
        ((ConvBnRelu > ConvBnRelu > ConvBn) | ConvBn) > Add > Relu
        """, "BottleneckBlock", "Bottleneck Block"),
    # Fold residual blocks
    hl.transforms.Fold("""ConvBnRelu > ConvBnRelu > ConvBn > Add > Relu""",
                       "ResBlock", "Residual Block"),
    # Fold repeated blocks
    hl.transforms.FoldDuplicates(),
]
# Display graph using the transforms above
hl.build_graph(model, torch.zeros([1, 3, 224, 224]), transforms=transforms)
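`Fold` replaces every match of an op pattern with a single node, and `FoldDuplicates` merges runs of identical blocks, with the transforms applied in list order. The toy sketch below is not HiddenLayer's implementation. It assumes a purely linear chain of op names (real ResNet graphs branch at `Add`, which this cannot express), and `fold_sequence`, `fold_duplicates` and the `xN` label are hypothetical. It does show why the longer pattern `Conv > BatchNorm > Relu` should come before `Conv > BatchNorm`: folding the shorter one first leaves each `Relu` stranded, so the longer pattern can no longer match.

```python
from typing import List


def fold_sequence(ops: List[str], pattern: List[str], name: str) -> List[str]:
    """Replace each non-overlapping run of `pattern` (scanned left to right)
    in a linear chain of op names with one node called `name`."""
    out: List[str] = []
    i, n = 0, len(pattern)
    while i < len(ops):
        if ops[i:i + n] == pattern:
            out.append(name)  # the whole matched run becomes one node
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


def fold_duplicates(ops: List[str]) -> List[str]:
    """Collapse consecutive identical nodes into one node labelled 'name xN'."""
    out: List[str] = []
    i = 0
    while i < len(ops):
        j = i
        while j < len(ops) and ops[j] == ops[i]:
            j += 1
        count = j - i
        out.append(ops[i] if count == 1 else f"{ops[i]} x{count}")
        i = j
    return out


chain = ["Conv", "BatchNorm", "Relu", "Conv", "BatchNorm", "Relu",
         "Conv", "BatchNorm", "MaxPool"]

# Same order as the transforms list: longest pattern first.
step1 = fold_sequence(chain, ["Conv", "BatchNorm", "Relu"], "ConvBnRelu")
step2 = fold_sequence(step1, ["Conv", "BatchNorm"], "ConvBn")
print(fold_duplicates(step2))  # ['ConvBnRelu x2', 'ConvBn', 'MaxPool']

# Reversed order: ConvBn swallows every Conv > BatchNorm, leaving bare Relu nodes.
wrong = fold_sequence(chain, ["Conv", "BatchNorm"], "ConvBn")
print(wrong)  # ['ConvBn', 'Relu', 'ConvBn', 'Relu', 'ConvBn', 'MaxPool']
```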
