3.2export_onnx
📦 1. 模型定义与导出代码
import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(num_features=16)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
self.bn2 = nn.BatchNorm2d(num_features=64)
self.act2 = nn.ReLU()
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(in_features=64, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.act2(x)
# flatten: B×C×H×W → B×C×L(L=H×W)
x = torch.flatten(x, 2, 3)
# 平均池化:B×C×L → B×C×1
x = self.avgpool(x)
# 再次 flatten:B×C×1 → B×C
x = torch.flatten(x, 1)
# 全连接层分类:B×C → B×10
x = self.head(x)
return x
📤 2. 导出为 ONNX 并简化
def export_norm_onnx():
input = torch.rand(1, 3, 64, 64) # 输入:B×3×64×64
model = Model()
file = "./sample-reshape.onnx"
# 导出 ONNX 模型
torch.onnx.export(
model = model,
args = (input,),
f = file,
input_names = ["input0"],
output_names = ["output0"],
opset_version = 15
)
print("Finished normal onnx export")
# 检查模型结构合法性
model_onnx = onnx.load(file)
onnx.checker.check_model(model_onnx)
# 使用 onnx-simplifier 进行图结构简化
print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, "assert check failed"
onnx.save(model_onnx, file)
🧠 小提示:为什么 flatten 会生成多个节点?
x = torch.flatten(x, 2, 3)
ONNX 中不支持 flatten(x, start_dim=2) 这样的高维展开直接表示,因此 PyTorch 导出时会转换为:
Shape:获取张量形状Slice:提取要 flatten 的维度Concat:拼接新 shapeReshape:完成 flatten 动作
使用 onnxsim 简化后,这些操作通常会被合并为一个简单的 Flatten 或 Reshape。
✅ 3. 主函数执行导出流程
if __name__ == "__main__":
export_norm_onnx()
🔧 代码结构整体说明:
🔹 模型结构(Model 类):
x -> conv1 -> bn1 -> relu1
-> conv2 -> bn2 -> relu2
-> flatten -> avgpool -> flatten -> linear -> output
其中重点在于:
🔸 第一段 flatten:
x = torch.flatten(x, 2, 3) # B, C, H, W -> B, C, L
这个操作会导致导出的 ONNX 图中生成:
ShapeSliceConcatReshape
等一系列辅助节点。为什么?
🧠 为什么 flatten(x, 2, 3) 会变成这么多 ONNX 节点?
PyTorch 的 torch.flatten(x, 2, 3) 表示:
- 把
x从第 2 维(H)到第 3 维(W)展平为一维 - 举个例子:输入
x是[B, C, H, W],flatten 后变成[B, C, H*W]
但是在 ONNX 中:
- ONNX 不支持 “动态切片 + flatten” 作为单一原始操作
- 所以需要分解为多个步骤来实现:
1. Shape:先获取 x 的形状
2. Slice:抽取你需要的维度值(这里是 H 和 W)
3. Concat:拼接出新 shape,例如 [B, C, H*W]
4. Reshape:应用这个新 shape
这就是你看到的:
Shape -> Slice -> Slice -> Mul -> Concat -> Reshape
的由来。
🔍 为什么导出前后图不一样?
你有两个版本:
✅ 原始导出图:
- 有上述所有细化节点(Slice/Shape/Reshape 等)
- 这对于 动态输入尺寸 很重要,但会让图复杂
✅ 简化后的 ONNX(使用 onnxsim.simplify):
- 会自动识别这部分是一个 flatten 动作
- 用更简洁的方式重新表达(甚至直接用一个
Flatten节点)
这是为什么你写了:
# onnx中其实会有一些constant value,以及不需要计算图跟踪的节点
# 大家可以一起从netron中看看这些节点都在干什么
🔍 平铺流程:flatten + avgpool + flatten + fc
你原始的网络有这几步转换:
| 步骤 | 输入维度 | 输出维度 | 说明 |
|---|---|---|---|
flatten(x, 2, 3) |
[B,C,H,W] |
[B,C,L] |
H × W 展平为 L |
AdaptiveAvgPool1d(1) |
[B,C,L] |
[B,C,1] |
类似全局平均池化 |
flatten(x, 1) |
[B,C,1] |
[B,C] |
去掉最后一维 |
Linear |
[B,C] |
[B,10] |
最终全连接层分类 |
🧪 建议你动手做以下实验理解更深:
- 注释掉
onnxsim.simplify(),用 Netron 打开.onnx文件,看看flatten变成了哪些低层操作? - 然后再运行一次
simplify,看看有没有把它们合并成一个Flatten或更简洁的结构? - 把
torch.flatten(x, 2, 3)换成.view(b, c, -1)或.reshape(...),看看导出的结构是否更简洁?
✅ 总结重点
| 项 | 内容 |
|---|---|
| flatten 操作为什么变复杂? | 因为 ONNX 中 flatten 只支持从第 1 个维度开始,如果你指定的是 2~3,会生成 shape/slice/reshape |
onnxsim.simplify 作用? |
自动识别复杂逻辑并简化(合并 slice、reshape 等) |
| 推荐做法? | 导出前先理解动态维度怎么计算,导出后建议简化以减小模型体积、提升兼容性 |
| 哪些操作最容易生成冗余图? | flatten、transpose、reshape、permute、expand 等涉及动态 shape 的操作 |

浙公网安备 33010602011771号