TensorRT模型查看input shape

Requirement

以下代码经测试,在TensorRT 10.1版本中可以正常运行。
pip安装信息如下:

tensorrt-cu11                            10.1.0
tensorrt-cu11-bindings                   10.1.0
tensorrt-cu11-libs                       10.1.0
tensorrt-lean-cu11                       10.1.0
tensorrt-lean-cu11-bindings              10.1.0
tensorrt-lean-cu11-libs                  10.1.0

代码如下

import tensorrt as trt

# Serialized TensorRT engine file to inspect.
ENGINE_PATH = "/imm_alg_model/svtrv2_dynamicBW_fp16_lean_t4.trt"

# Logger handed to the runtime; INFO is the minimum severity it reports.
logger = trt.Logger(trt.Logger.INFO)

# Read the whole engine file and deserialize it into an engine object.
with open(ENGINE_PATH, "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# Per the TensorRT Python API docs, deserialize_cuda_engine() returns None
# when the engine cannot be deserialized. Fail here with a clear message
# rather than with an unrelated AttributeError on None further down.
if engine is None:
    raise RuntimeError(
        f"Failed to deserialize TensorRT engine {ENGINE_PATH!r} "
        f"with TensorRT {trt.__version__}; see the logger output above."
    )

print("TensorRT version:", trt.__version__)

# ------------------------------------------------------------------
# 1. TensorRT 9.x / 10.x -- name-based I/O-tensor interface
# ------------------------------------------------------------------
# The presence of `num_io_tensors` selects the name-based API.
# NOTE(review): the original note said this attribute exists only from
# TRT 9 on; NVIDIA's release notes suggest it was introduced in 8.5, in
# which case 8.5+ engines take this branch as well -- confirm.
if hasattr(engine, "num_io_tensors"):
    print("=== I/O tensors (new API) ===")
    # Resolve every tensor name once; both reports below iterate over them.
    tensor_names = [engine.get_tensor_name(k) for k in range(engine.num_io_tensors)]
    for name in tensor_names:
        # The mode is either trt.TensorIOMode.INPUT or OUTPUT.
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            io = "Input "
        else:
            io = "Output"
        dtype = engine.get_tensor_dtype(name)
        shape = engine.get_tensor_shape(name)  # dynamic dimensions show up as -1
        print(f"{io:6}: {name:<30} {shape} {dtype}")

    # For dynamic-shape models, also show the shape range of the first
    # optimization profile.
    if engine.num_optimization_profiles > 0:
        profile = 0
        print(f"\n[Profile {profile}] min/opt/max shapes")
        for name in tensor_names:
            # Usually only the inputs are of interest here.
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                shapes = engine.get_tensor_profile_shape(name, profile)  # (min, opt, max)
                print(f"{name:<30} : {shapes}")

# ------------------------------------------------------------------
# 2. TensorRT 8.x and earlier -- index-based binding interface
# ------------------------------------------------------------------
else:
    print("=== Bindings (old API) ===")
    # A range object can be iterated repeatedly, so both loops share it.
    binding_indices = range(engine.num_bindings)
    for i in binding_indices:
        name = engine.get_binding_name(i)
        if engine.binding_is_input(i):
            io = "Input "
        else:
            io = "Output"
        dtype = engine.get_binding_dtype(i)
        shape = engine.get_binding_shape(i)
        print(f"{io:6}: {name:<30} {shape} {dtype}")

    # Same first-profile report as above, restricted to input bindings.
    if engine.num_optimization_profiles > 0:
        profile = 0
        print(f"\n[Profile {profile}] min/opt/max shapes")
        for i in binding_indices:
            if engine.binding_is_input(i):
                shapes = engine.get_profile_shape(profile, i)  # (min, opt, max)
                print(f"{engine.get_binding_name(i):<30} : {shapes}")

posted @ 2025-08-01 10:15  叶子喧闹  阅读(17)  评论(0)    收藏  举报