LSTM 模型可视化
“看见”模型是理解深度学习最直观的方式。在 PyTorch 中,有三个层次的方法可以让你“看见”刚才搭建的 LSTM 模型:
- 结构层:看代码层级(最简单)。
- 数据层:看张量流动的维度(最实用,强烈推荐)。
- 图形层:画出网络拓扑图(最酷炫)。
以下是具体的操作方法:
方法一:直接打印 (最简单)
这是 PyTorch 自带的功能。虽然它不能显示维度,但能让你确认模型里有哪些层。
# 1. 实例化你的模型
# 假设词表大小=1000, 词向量=100, 隐藏层=128, 二分类
model = LSTMClassifier(vocab_size=1000, embedding_dim=100, hidden_dim=128, output_dim=2)
# 2. 直接 print
print(model)
你会看到这样的输出:
LSTMClassifier(
(embedding): Embedding(1000, 100, padding_idx=0)
(lstm): LSTM(100, 128, batch_first=True)
(fc): Linear(in_features=128, out_features=2, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)
方法二:使用 torchinfo 查看维度 (最推荐 ⭐️)
这是工业界查错的神器。它能告诉你每一层的数据进去是多少维,出来是多少维,还能统计参数量。
你需要先安装:
pip install torchinfo
然后运行代码:
from torchinfo import summary
import torch
# 实例化模型
model = LSTMClassifier(vocab_size=1000, embedding_dim=100, hidden_dim=128, output_dim=2)
# 构造一个假输入 (Batch_Size=32, Seq_Len=50)
# 注意:Embedding层的输入必须是整数 (Long),所以我们要指定 dtypes
batch_size = 32
seq_len = 50
summary(model,
input_size=(batch_size, seq_len),
dtypes=[torch.long], # 告诉它是整数输入
col_names=["input_size", "output_size", "num_params"], # 显示你关心的列
verbose=1)
你会看到这样的神级表格(能够解答你关于 Squeeze 的所有疑惑):
=============================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param #
=============================================================================================
LSTMClassifier [32, 50] [32, 2] --
├─Embedding: 1-1 [32, 50] [32, 50, 100] 100,000
├─LSTM: 1-2 [32, 50, 100] [32, 50, 128] 117,760
├─Dropout: 1-3 [32, 128] [32, 128] --
├─Linear: 1-4 [32, 128] [32, 2] 258
=============================================================================================
Total params: 218,018
...
注意看 LSTM 的输出是 [32, 50, 128],而 Dropout 的输入变成了 [32, 128]。这就证明了你的代码里提取 last_hidden 的逻辑生效了!
方法三:可视化计算图 (进阶)
如果你想看到像论文里那样的“节点连接图”,可以使用 torchviz。
你需要安装:
pip install torchviz
# 注意:这通常还需要你在系统里安装 Graphviz 软件 (非Python库)
代码:
from torchviz import make_dot
# 假数据
dummy_input = torch.randint(0, 1000, (32, 50))
# 前向传播一次得到输出
y = model(dummy_input)
# 画图
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("lstm_model_view", format="png") # 会在目录下生成一个 png 图片
这会生成一张复杂的连线图,展示数据是如何经过每一个算子(Add, Mul, Tanh)流动的。
方法四:Netron (上帝视角)
还有一个非常流行的方法是把模型导出为通用格式 (ONNX),然后用浏览器看。
- 导出模型:
torch.onnx.export(model, dummy_input, "model.onnx") - 查看:打开网站 Netron.app,把生成的
model.onnx拖进去。 - 效果:你会看到一个非常漂亮、可交互的模型结构图,点击每一层还能看到具体的权重矩阵数值。

浙公网安备 33010602011771号