Visualizing PyTorch Models

Setup

pip install torch torchvision -i https://download.pytorch.org/whl/cu126

Define the model and a dummy input:

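from torchvision.models import resnet18
import torch as th

# Dummy input: a batch of one 3-channel (RGB) 224x224 image
x = th.randn(1, 3, 224, 224)
# ResNet-18 with randomly initialized weights (nothing is downloaded)
model = resnet18()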

torchview

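pip install torchview

torchview draws the graph with Graphviz, so the graphviz Python package and the system `dot` executable must also be available to render an image (on Debian/Ubuntu: `sudo apt-get install graphviz`).

from torchview import draw_graph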

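model_graph = draw_graph(
    model,
    input_size=x.shape,    # torchview creates its own random input of this shape
    expand_nested=True,    # draw nested modules (layer1, its BasicBlocks, ...) as enclosing boxes
    save_graph=True,       # render the graph to an image file
    filename="torchview",  # output name; the image is written as torchview.png
    directory="."          # output directory
)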

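# visual_graph is a graphviz.Digraph; in a Jupyter notebook, evaluating it displays the graph inline
model_graph.visual_graph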


netron

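pip install onnx netron

Export the model to ONNX (verbose=True prints the exported graph), then open the file with Netron, which starts a local web server and shows the model in the browser:

th.onnx.export(model, x, "resnet18.onnx", verbose=True)

netron resnet18.onnx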


Reference: PyTorch 模型网络可视化画图工具合集 (A Collection of PyTorch Network Visualization Tools) | Medium

posted @ 2025-06-17 19:06  Undefined443