TensorRT模型查看input shape
Requirement
以下代码经测试，在 TensorRT 10.1 版本中可以正常运行。
pip安装信息如下:
tensorrt-cu11 10.1.0
tensorrt-cu11-bindings 10.1.0
tensorrt-cu11-libs 10.1.0
tensorrt-lean-cu11 10.1.0
tensorrt-lean-cu11-bindings 10.1.0
tensorrt-lean-cu11-libs 10.1.0
代码如下
import tensorrt as trt

# Serialized TensorRT engine whose I/O tensors are listed below.
ENGINE_PATH = "/imm_alg_model/svtrv2_dynamicBW_fp16_lean_t4.trt"

logger = trt.Logger(trt.Logger.INFO)
with open(ENGINE_PATH, "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# Printed before the validity check on purpose: a TensorRT version mismatch is
# a common reason for a failed deserialization, so the version is useful context.
print("TensorRT version:", trt.__version__)

# deserialize_cuda_engine() signals failure by returning None instead of
# raising.  Without this guard, hasattr(None, "num_io_tensors") is False and
# the script would fall into the legacy branch and die with a misleading
# AttributeError on `num_bindings`.
if engine is None:
    raise RuntimeError(
        f"Failed to deserialize TensorRT engine from {ENGINE_PATH!r}; "
        "see the TensorRT logger output above for the cause."
    )

# ------------------------------------------------------------------
# 1. Name-based I/O tensor API (introduced in TensorRT 8.5, and the only
#    API available in TensorRT 10.x)
# ------------------------------------------------------------------
if hasattr(engine, "num_io_tensors"):
    print("=== I/O tensors (new API) ===")
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        mode = engine.get_tensor_mode(name)  # trt.TensorIOMode.INPUT / OUTPUT
        io = "Input " if mode == trt.TensorIOMode.INPUT else "Output"
        dtype = engine.get_tensor_dtype(name)
        shape = engine.get_tensor_shape(name)  # dynamic dimensions show up as -1
        print(f"{io:6}: {name:<30} {shape} {dtype}")

    # For dynamic-shape models, also show the range allowed by the
    # optimization profile.  Only the first profile (index 0) is reported.
    if engine.num_optimization_profiles > 0:
        profile = 0
        print(f"\n[Profile {profile}] min/opt/max shapes")
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            mode = engine.get_tensor_mode(name)
            if mode != trt.TensorIOMode.INPUT:  # usually only inputs matter
                continue
            shapes = engine.get_tensor_profile_shape(name, profile)  # (min, opt, max)
            print(f"{name:<30} : {shapes}")

# ------------------------------------------------------------------
# 2. Legacy index-based binding API (older TensorRT 8.x releases and earlier)
# ------------------------------------------------------------------
else:
    print("=== Bindings (old API) ===")
    for i in range(engine.num_bindings):
        name = engine.get_binding_name(i)
        io = "Input " if engine.binding_is_input(i) else "Output"
        dtype = engine.get_binding_dtype(i)
        shape = engine.get_binding_shape(i)
        print(f"{io:6}: {name:<30} {shape} {dtype}")

    # Same profile dump as above, via the binding-index API (first profile only).
    if engine.num_optimization_profiles > 0:
        profile = 0
        print(f"\n[Profile {profile}] min/opt/max shapes")
        for i in range(engine.num_bindings):
            if not engine.binding_is_input(i):
                continue
            shapes = engine.get_profile_shape(profile, i)  # (min, opt, max)
            print(f"{engine.get_binding_name(i):<30} : {shapes}")

浙公网安备 33010602011771号