Deploying a PyTorch model by converting it to TensorRT

PyTorch to ONNX

First, load the PyTorch model:

# load model
import torch
def load_model(ckpt):
    # build model (replace build_model with your own model construction function)
    model = build_model()
    # load checkpoint weights
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state"])
    return model

Then use torch.onnx to export the PyTorch model to ONNX:

def export_onnx(model, onnx_name, batch_size):
    height, width = 224, 224  # placeholder: replace with your model's input resolution
    # dummy input in NCHW layout, used to trace the model during export
    img = torch.randn((batch_size, 3, height, width)).cuda()
    model = model.cuda().eval()
    torch.onnx.export(model,
                      img,
                      onnx_name,
                      export_params=True,
                      opset_version=11,
                      input_names=["input"],
                      output_names=["output"],
                      do_constant_folding=True,
                      verbose=True
    )
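Before moving on, it is worth sanity-checking the exported file. A minimal sketch using the onnx package (an extra dependency, not part of the original workflow):

import onnx

# raises an exception if the exported graph is malformed
onnx_model = onnx.load("your_model_name.onnx")
onnx.checker.check_model(onnx_model)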

ONNX to TensorRT

First install TensorRT (see the installation guide at link). After that, there are two ways to do the conversion: 1. the trtexec command, or 2. a Python script.

  1. The trtexec command
 trtexec --onnx=path/to/onnx --saveEngine=path/to/save/trt --explicitBatch --fp16 --workspace=15000

If you get a "trtexec: command not found" error, locate your TensorRT installation directory (e.g. /usr/local/tensorrt) and replace trtexec above with /usr/local/tensorrt/bin/trtexec. To avoid typing the full path each time, add the following line to ~/.bashrc:

alias trtexec="/usr/local/tensorrt/bin/trtexec"

Save and exit, then run source ~/.bashrc; after that the trtexec command is available.
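To confirm that a saved engine loads and runs, trtexec can also benchmark an existing engine file:

 trtexec --loadEngine=path/to/save/trt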

  2. Python script

import os

import tensorrt as trt
import torch

# load_model and export_onnx are the functions defined above
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
def get_engine(onnx_file_path, engine_file_path, using_half):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with."""
        device = torch.device('cuda:{}'.format(0))
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:

            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 30
            if using_half:
                config.set_flag(trt.BuilderFlag.FP16)

            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run the ONNX export first to generate it.'.format(onnx_file_path))
                exit(1)
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    # surface parser errors instead of failing silently
                    for i in range(parser.num_errors):
                        print(parser.get_error(i))
                    raise RuntimeError('Failed to parse the ONNX file')
            with torch.cuda.device(device):
                engine = builder.build_engine(network, config)
            assert engine is not None, 'Failed to create TensorRT engine'
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()


if __name__ == '__main__':
    batch_size = 1  # fixed batch size baked into the TRT engine; PyTorch timings are for non-batched data
    using_half = True
    model_name = 'your_model_name'
    model_path = 'path/to/pth'
    onnx_path = '{name}.onnx'.format(name=model_name)

    with torch.no_grad():
        model = load_model(model_path)
        export_onnx(model, onnx_path, batch_size)
        engine = get_engine(onnx_path,
                            '{name}.trt'.format(name=model_name),
                            using_half)

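The script above only builds and serializes the engine. To actually run inference, you need to allocate device buffers and copy data to and from the GPU. Below is a minimal sketch using pycuda, assuming the engine has exactly one input binding (index 0) and one output binding (index 1); the infer helper is illustrative, and it uses the same TensorRT 7/8-era binding API as the build script, not the named-tensor API of newer releases.

import numpy as np
import pycuda.autoinit  # creates a CUDA context on import
import pycuda.driver as cuda
import tensorrt as trt

def infer(engine, input_array):
    # assumes binding 0 is the input and binding 1 is the output
    with engine.create_execution_context() as context:
        h_input = np.ascontiguousarray(input_array, dtype=np.float32)
        h_output = np.empty(trt.volume(engine.get_binding_shape(1)), dtype=np.float32)
        d_input = cuda.mem_alloc(h_input.nbytes)
        d_output = cuda.mem_alloc(h_output.nbytes)
        stream = cuda.Stream()
        # host -> device, run the engine, device -> host, all on one stream
        cuda.memcpy_htod_async(d_input, h_input, stream)
        context.execute_async_v2(bindings=[int(d_input), int(d_output)],
                                 stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(h_output, d_output, stream)
        stream.synchronize()
        return h_output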

Before acceleration, processing one image took about 50 ms; after TensorRT acceleration, inference takes about 10 ms.

Reference: pytorch模型转TensorRT模型部署
