深度学习(onnx合并)
多个onnx可以合并为一个onnx,这样在c++调用的时候会方便一些。
如果有原始pytorch代码和模型参数,可以在导出时合并。
如果只有onnx文件,可以用下面的方法合并,每个模型各自是各自的输入输出。
import onnx import torch import torch.nn as nn import onnxruntime as ort class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.linear = nn.Linear(10, 5) def forward(self, x): return self.linear(x) torch_model1 = SimpleModel() torch_model2 = SimpleModel() torch.onnx.export(torch_model1, torch.randn(1, 10), "model1.onnx", verbose=False) torch.onnx.export(torch_model2, torch.randn(1, 10), "model2.onnx", verbose=False) onnx_model1 = onnx.load("model1.onnx") onnx_model2 = onnx.load("model2.onnx") onnx_model1 = onnx.compose.add_prefix(onnx_model1, prefix="model1_") onnx_model2 = onnx.compose.add_prefix(onnx_model2, prefix="model2_") onnx_model1.graph.input.extend(onnx_model2.graph.input) onnx_model1.graph.node.extend(onnx_model2.graph.node) onnx_model1.graph.initializer.extend(onnx_model2.graph.initializer) onnx_model1.graph.output.extend(onnx_model2.graph.output) onnx.save(onnx_model1, "fus_model.onnx") #推理对比结果 session = ort.InferenceSession("fus_model.onnx") input_name1 = session.get_inputs()[0].name input_name2 = session.get_inputs()[1].name x = torch.randn(1, 10) outputs = session.run(None, {input_name1: x.numpy() ,input_name2: x.numpy()}) print(torch_model1(x).detach().numpy(),outputs[0]) print(torch_model2(x).detach().numpy(),outputs[1])