Deep Learning (Converting ONNX to pth)

The usual deep learning workflow is to train a model, export it as a pth file, and then convert it to ONNX.

But if all you have is an ONNX file, and its structure is not especially complex, you can work backwards from the ONNX graph and reconstruct the PyTorch model.

Below we use Horizon's mixvargenet as an example; only the ONNX file could be found in the development environment.

Model download: https://github.com/HorizonRobotics-Platform/ModelZoo/tree/master/MixVarGENet

The key points of the conversion are:

1. Inspect the ONNX graph in Netron and write out the corresponding PyTorch model code.

2. Once the model is written, export it back to ONNX and use onnx_tool.model_profile to check that each layer of the new ONNX matches the original (a sketch of this follows the list).

3. When PyTorch exports to ONNX, each conv+bn pair is fused into a single conv layer. The original model almost certainly had bn layers, so the code keeps them but commented out (recovering the bn parameters from the fused conv would be quite difficult; see the folding sketch at the end of this post).

4. Fill the ONNX model's parameters into the PyTorch model; if the dimensions are not aligned, an error will be raised.

5. Finally, compare the inference results of the two models to verify that the whole conversion is correct.
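A minimal sketch of steps 1 and 2, assuming the MixVarGeNet class from the full script below. The rebuilt file name 'mixvargenet_rebuilt.onnx' and opset_version=11 are my own choices, not from the original model, and onnx_tool's exact arguments may differ between versions:

import torch
import onnx_tool
# import netron; netron.start('mixvargenet_704x1280.onnx')  # step 1: inspect the graph in a browser

# MixVarGeNet is the class defined in the full script below.
model = MixVarGeNet()
model.eval()

# Export the reconstructed model with a dummy input of the original resolution.
dummy = torch.randn(1, 3, 704, 1280)
torch.onnx.export(model, dummy, 'mixvargenet_rebuilt.onnx', opset_version=11)

# Print per-layer shapes/MACs/params for both files and compare them side by side.
onnx_tool.model_profile('mixvargenet_704x1280.onnx')
onnx_tool.model_profile('mixvargenet_rebuilt.onnx')

If the two profiles list the same operators with the same output shapes layer by layer, the reconstructed structure matches the original.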

import torch
import torch.nn as nn
import onnx
import numpy as np
import onnxruntime as ort

# SSC: stride-1 block, 3x3 grouped conv -> ReLU -> 1x1 conv, with an identity shortcut.
class SSC(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, groups):
        super(SSC, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=groups)
        # self.bn1 = nn.BatchNorm2d(hidden_channels)  # fused into conv1 in the ONNX export
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1)
        # self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        # out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        # out = self.bn2(out)
        out += x
        out = self.relu(out)
        return out
    
# CSC: downsampling block; conv1 is a strided 1x1 shortcut, conv2+conv3 form the main branch.
class CSC(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, groups):
        super(CSC, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, padding=0, dilation=1, groups=1)
        # self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=2, padding=1, dilation=1, groups=groups)
        # self.bn2 = nn.BatchNorm2d(hidden_channels)
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1)
        # self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)   # shortcut branch
        # out = self.bn1(out)
        x = self.conv2(x)     # main branch
        # x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        # x = self.bn3(x)
        out += x
        out = self.relu(out)
        return out

# Backbone reconstructed from the ONNX graph: stem conv, 17 SSC/CSC blocks, 1x1 head.
class MixVarGeNet(nn.Module):
    def __init__(self):
        super(MixVarGeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, dilation=1, groups=1)
        # self.bn1 = nn.BatchNorm2d(32)
        self.blocks = nn.Sequential(
            SSC(32, 64, 32, 1),
            CSC(32, 128, 32, 1),
            SSC(32, 128, 32, 1),
            SSC(32, 128, 32, 1),
            CSC(32, 128, 64, 1),
            SSC(64, 256, 64, 1),
            SSC(64, 256, 64, 1),
            CSC(64, 128, 96, 4),
            SSC(96, 192, 96, 6),
            SSC(96, 192, 96, 6),
            SSC(96, 192, 96, 6),
            SSC(96, 192, 96, 6),
            SSC(96, 192, 96, 6),
            SSC(96, 192, 96, 6),
            CSC(96, 192, 160, 6),
            SSC(160, 320, 160, 10),
            SSC(160, 320, 160, 10))

        self.conv2 = nn.Conv2d(160, 1024, kernel_size=1, stride=1, padding=0, dilation=1, groups=1)
        # self.bn2 = nn.BatchNorm2d(1024)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        self.out = nn.Conv2d(1024, 1000, kernel_size=1, stride=1, padding=0, dilation=1, groups=1)
        # self.bn3 = nn.BatchNorm2d(1000)

    def forward(self, x):
        x = self.conv1(x)
        # x = self.bn1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        # x = self.bn2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = self.out(x)
        # x = self.bn3(x)
        return x

if __name__ == "__main__":

    onnx_name = 'mixvargenet_704x1280.onnx'

    # Run the original ONNX model on a random input.
    x = torch.randn(1, 3, 704, 1280)
    ort_sess = ort.InferenceSession(onnx_name)
    y_onnx = torch.from_numpy(ort_sess.run(None, {"input.1": x.numpy()})[0])

    # Collect the ONNX initializers (weights) in graph order.
    onnx_model = onnx.load(onnx_name)
    onnx_params = {tensor.name: tensor for tensor in onnx_model.graph.initializer}
    keylist = list(onnx_params.keys())

    # Copy each initializer into the matching PyTorch parameter (step 4).
    # This relies on named_parameters() and the initializers being in the same order;
    # view() raises an error if the element counts do not match.
    model = MixVarGeNet()
    for id, data in enumerate(model.named_parameters()):
        name, param = data[0], data[1]
        array_data = np.frombuffer(onnx_params[keylist[id]].raw_data, dtype=np.float32)
        print(id, keylist[id], name, param.size(), array_data.size)

        onnx_param = torch.from_numpy(np.array(array_data)).view(param.size())
        param.data.copy_(onnx_param)

    # Compare the two models' outputs (step 5).
    model.eval()
    with torch.no_grad():
        y = model(x)

    allclose = torch.allclose(y, y_onnx, atol=1e-7)
    print(allclose)
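As a footnote to step 3: the exporter folds each bn into the preceding conv, which is why the bn parameters cannot be read back out of the ONNX file. Below is a minimal sketch of the standard conv+bn folding math; fuse_conv_bn is my own illustrative helper, not code from the original model:

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    # collapses into one conv with per-channel scaled weights and a shifted bias.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)      # gamma / sqrt(var + eps)
    fused.weight.data = conv.weight.data * scale.view(-1, 1, 1, 1)    # scale each output channel
    bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused

Four bn statistics per channel (gamma, beta, running mean, running variance) collapse into a single scale and shift, so the fused conv alone does not determine them uniquely; that is why the commented-out bn layers in the script cannot be restored from the ONNX weights.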