PyTorch 中自定义 ONNX 新算子并导出为 ONNX

import torch
from torch.autograd import Function
import torch.onnx

# Step 1: Define custom PyTorch operator
class MyCustomOp(Function):
    """Autograd function that adds one to its input.

    ``forward`` computes ``input + 1``; ``backward`` propagates the gradient
    unchanged; ``symbolic`` describes the op as a single ``CustomAddOne``
    graph node for ONNX export.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``input + 1``."""
        return input + 1

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to ``input``.

        The derivative of ``input + 1`` with respect to ``input`` is 1, so the
        incoming gradient is passed through as-is. Without this method any
        backward pass through the op raises ``NotImplementedError``.
        """
        return grad_output

    @staticmethod
    def symbolic(g, input):
        """Emit a ``CustomAddOne`` node that takes ``input`` as its only input."""
        # NOTE: the inputs passed here must match those of the TensorRT
        # plugin layer that implements this op later on.
        return g.op("CustomAddOne", input)

def custom_add_one(input):
    """Run ``input`` through ``MyCustomOp`` and return the result."""
    result = MyCustomOp.apply(input)
    return result

# Step 2: Register custom ONNX operator
def custom_add_one_symbolic(g, input):
    """Build a ``CustomAddOne`` node on graph ``g`` with ``input`` as its input.

    Returns whatever ``g.op`` returns for the new node.
    """
    node = g.op("CustomAddOne", input)
    return node

# Register custom_add_one_symbolic under the name "::custom_add_one" for
# opset version 9.
# NOTE(review): MyCustomOp already defines its own `symbolic` staticmethod, so
# this registration may be redundant for the autograd.Function path — confirm.
torch.onnx.register_custom_op_symbolic("::custom_add_one", custom_add_one_symbolic, 9)

# Step 3: Export to ONNX
class MyModel(torch.nn.Module):
    """Minimal module whose forward pass is a single ``custom_add_one`` call."""

    def forward(self, x):
        """Return ``custom_add_one(x)``."""
        output = custom_add_one(x)
        return output

# Instantiate the model and draw a random example input for the export.
model = MyModel()
dummy_input = torch.randn((1, 3, 224, 224))

# Export the model to "custom_model.onnx" using opset version 9.
torch.onnx.export(
    model,
    dummy_input,
    "custom_model.onnx",
    opset_version=9,
    custom_opsets={"": 9},
)

print("ONNX model with custom operator exported successfully.")

 

posted @ 2024-07-28 17:46  海_纳百川  阅读(500)  评论(0)    收藏  举报
本站总访问量