Pytorch 统计模型参数量 param.numel()【转】

Pytorch 统计模型参数量 param.numel()

 

        total_params = 0
        trainable_params = 0
        
        for param in self.lane_head.parameters():
            param_count = param.numel()  # 获取单个参数的元素个数
            total_params += param_count
            if param.requires_grad:
                trainable_params += param_count

        # 格式化输出(转换为 M/千,保留2位小数)
        def format_params(num: int) -> str:
            if num >= 1e6:
                return f"{num / 1e6:.2f}M"
            elif num >= 1e3:
                return f"{num / 1e3:.2f}K"
            return f"{num}"

        print(f"=== 模型参数量统计 ===")
        print(f"总参数量: {format_params(total_params)} ({total_params:,} 个)")
        print(f"可训练参数量: {format_params(trainable_params)} ({trainable_params:,} 个)")
        print(f"不可训练参数量: {format_params(total_params - trainable_params)}")

 

posted @ 2020-12-01 16:30  Picassooo  阅读(234)  评论(0)    收藏  举报