深度学习(onnx合并)

多个onnx可以合并为一个onnx,这样在c++调用的时候会方便一些。

如果有原始pytorch代码和模型参数,可以在导出时合并。

如果只有onnx文件,可以用下面的方法合并,每个模型各自是各自的输入输出。

import onnx
import torch
import torch.nn as nn
import onnxruntime as ort

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

torch_model1 = SimpleModel()
torch_model2 = SimpleModel()

torch.onnx.export(torch_model1, torch.randn(1, 10), "model1.onnx", verbose=False)
torch.onnx.export(torch_model2, torch.randn(1, 10), "model2.onnx", verbose=False)

onnx_model1 = onnx.load("model1.onnx")
onnx_model2 = onnx.load("model2.onnx")

onnx_model1 = onnx.compose.add_prefix(onnx_model1, prefix="model1_")
onnx_model2 = onnx.compose.add_prefix(onnx_model2, prefix="model2_")

onnx_model1.graph.input.extend(onnx_model2.graph.input)
onnx_model1.graph.node.extend(onnx_model2.graph.node)
onnx_model1.graph.initializer.extend(onnx_model2.graph.initializer)
onnx_model1.graph.output.extend(onnx_model2.graph.output)

onnx.save(onnx_model1, "fus_model.onnx")

#推理对比结果
session = ort.InferenceSession("fus_model.onnx")
input_name1 = session.get_inputs()[0].name  
input_name2 = session.get_inputs()[1].name

x = torch.randn(1, 10)
outputs = session.run(None, {input_name1: x.numpy() ,input_name2: x.numpy()})  

print(torch_model1(x).detach().numpy(),outputs[0])
print(torch_model2(x).detach().numpy(),outputs[1])
posted @ 2025-04-19 22:25  Dsp Tian  阅读(139)  评论(0)    收藏  举报