Pytorch 统计模型参数量 param.numel()
total_params = 0
trainable_params = 0
for param in self.lane_head.parameters():
param_count = param.numel() # 获取单个参数的元素个数
total_params += param_count
if param.requires_grad:
trainable_params += param_count
# 格式化输出(转换为 M/千,保留2位小数)
def format_params(num: int) -> str:
if num >= 1e6:
return f"{num / 1e6:.2f}M"
elif num >= 1e3:
return f"{num / 1e3:.2f}K"
return f"{num}"
print(f"=== 模型参数量统计 ===")
print(f"总参数量: {format_params(total_params)} ({total_params:,} 个)")
print(f"可训练参数量: {format_params(trainable_params)} ({trainable_params:,} 个)")
print(f"不可训练参数量: {format_params(total_params - trainable_params)}")