Visualizing PyTorch Models
Setup
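# The cu126 index serves CUDA 12.6 builds; for a CPU-only machine use https://download.pytorch.org/whl/cpu instead.
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126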
Define the model and the input data:
from torchvision.models import resnet18
import torch as th
x = th.randn(1, 3, 224, 224)
model = resnet18()
torchview
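# torchview renders through Graphviz, so the Graphviz dot executable must also be installed on the system.
pip install graphviz torchview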
from torchview import draw_graph
model_graph = draw_graph(
model,
input_size=x.shape,
expand_nested=True,
save_graph=True,
filename="torchview",
directory="."
)
model_graph.visual_graph  # a graphviz.Digraph; in a Jupyter notebook this displays the graph inline

netron
pip install onnx netron
th.onnx.export(model, x, "resnet18.onnx", verbose=True)
netron resnet18.onnx
