快速统计 PyTorch 模型参数量

使用 .numel() 方法可以统计模型的参数量。

以下代码摘自 In-context Autoencoder (ICAE) 的代码仓库。输入 nn.Module,统计该模型的参数量和可训练参数量:

def print_trainable_parameters(model):
    trainable_parameters = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_parameters += param.numel()
    print(f"trainable params: {trainable_parameters} || all params: {all_param} || trainable%: {100 * trainable_parameters / all_param}")
posted @ 2024-06-28 10:48  倒地  阅读(270)  评论(0)    收藏  举报