3.2export_onnx

📦 1. 模型定义与导出代码

import torch
import torch.nn as nn
import torch.onnx
import onnxsim
import onnx

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1   = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1     = nn.BatchNorm2d(num_features=16)
        self.act1    = nn.ReLU()
        self.conv2   = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5, padding=2)
        self.bn2     = nn.BatchNorm2d(num_features=64)
        self.act2    = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head    = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)

        # flatten: B×C×H×W → B×C×L(L=H×W)
        x = torch.flatten(x, 2, 3)

        # 平均池化:B×C×L → B×C×1
        x = self.avgpool(x)

        # 再次 flatten:B×C×1 → B×C
        x = torch.flatten(x, 1)

        # 全连接层分类:B×C → B×10
        x = self.head(x)
        return x

📤 2. 导出为 ONNX 并简化

def export_norm_onnx():
    input   = torch.rand(1, 3, 64, 64)  # 输入:B×3×64×64
    model   = Model()
    file    = "./sample-reshape.onnx"

    # 导出 ONNX 模型
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = file,
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 15
    )
    print("Finished normal onnx export")

    # 检查模型结构合法性
    model_onnx = onnx.load(file)
    onnx.checker.check_model(model_onnx)

    # 使用 onnx-simplifier 进行图结构简化
    print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "assert check failed"
    onnx.save(model_onnx, file)

🧠 小提示:为什么 flatten 会生成多个节点?

x = torch.flatten(x, 2, 3)

ONNX 中不支持 flatten(x, start_dim=2) 这样的高维展开直接表示,因此 PyTorch 导出时会转换为:

  • Shape:获取张量形状
  • Slice:提取要 flatten 的维度
  • Concat:拼接新 shape
  • Reshape:完成 flatten 动作

使用 onnxsim 简化后,这些操作通常会被合并为一个简单的 FlattenReshape


✅ 3. 主函数执行导出流程

if __name__ == "__main__":
    export_norm_onnx()

🔧 代码结构整体说明:

🔹 模型结构(Model 类):

x -> conv1 -> bn1 -> relu1  
  -> conv2 -> bn2 -> relu2  
  -> flatten -> avgpool -> flatten -> linear -> output

其中重点在于:

🔸 第一段 flatten:

x = torch.flatten(x, 2, 3)  # B, C, H, W -> B, C, L

这个操作会导致导出的 ONNX 图中生成:

  • Shape
  • Slice
  • Concat
  • Reshape

等一系列辅助节点。为什么?


🧠 为什么 flatten(x, 2, 3) 会变成这么多 ONNX 节点?

PyTorch 的 torch.flatten(x, 2, 3) 表示:

  • x 从第 2 维(H)到第 3 维(W)展平为一维
  • 举个例子:输入 x[B, C, H, W],flatten 后变成 [B, C, H*W]

但是在 ONNX 中:

  • ONNX 不支持 “动态切片 + flatten” 作为单一原始操作
  • 所以需要分解为多个步骤来实现:

1. Shape:先获取 x 的形状

2. Slice:抽取你需要的维度值(这里是 HW

3. Concat:拼接出新 shape,例如 [B, C, H*W]

4. Reshape:应用这个新 shape

这就是你看到的:

Shape -> Slice -> Slice -> Mul -> Concat -> Reshape

的由来。


🔍 为什么导出前后图不一样?

你有两个版本:

✅ 原始导出图:

  • 有上述所有细化节点(Slice/Shape/Reshape 等)
  • 这对于 动态输入尺寸 很重要,但会让图复杂

✅ 简化后的 ONNX(使用 onnxsim.simplify):

  • 会自动识别这部分是一个 flatten 动作
  • 用更简洁的方式重新表达(甚至直接用一个 Flatten 节点)

这是为什么你写了:

# onnx中其实会有一些constant value,以及不需要计算图跟踪的节点
# 大家可以一起从netron中看看这些节点都在干什么

🔍 平铺流程:flatten + avgpool + flatten + fc

你原始的网络有这几步转换:

步骤 输入维度 输出维度 说明
flatten(x, 2, 3) [B,C,H,W] [B,C,L] H × W 展平为 L
AdaptiveAvgPool1d(1) [B,C,L] [B,C,1] 类似全局平均池化
flatten(x, 1) [B,C,1] [B,C] 去掉最后一维
Linear [B,C] [B,10] 最终全连接层分类

🧪 建议你动手做以下实验理解更深:

  1. 注释掉 onnxsim.simplify(),用 Netron 打开 .onnx 文件,看看 flatten 变成了哪些低层操作?
  2. 然后再运行一次 simplify,看看有没有把它们合并成一个 Flatten 或更简洁的结构?
  3. torch.flatten(x, 2, 3) 换成 .view(b, c, -1).reshape(...),看看导出的结构是否更简洁?

✅ 总结重点

内容
flatten 操作为什么变复杂? 因为 ONNX 中 flatten 只支持从第 1 个维度开始,如果你指定的是 2~3,会生成 shape/slice/reshape
onnxsim.simplify 作用? 自动识别复杂逻辑并简化(合并 slice、reshape 等)
推荐做法? 导出前先理解动态维度怎么计算,导出后建议简化以减小模型体积、提升兼容性
哪些操作最容易生成冗余图? flatten、transpose、reshape、permute、expand 等涉及动态 shape 的操作
posted @ 2025-07-30 21:20  小小怪历险记  阅读(70)  评论(0)    收藏  举报