快速统计 PyTorch 模型参数量
使用 .numel() 方法可以统计模型的参数量。
以下代码摘自 In-context Autoencoder (ICAE) 的代码仓库。输入 nn.Module,统计该模型的参数量和可训练参数量:
def print_trainable_parameters(model):
trainable_parameters = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_parameters += param.numel()
print(f"trainable params: {trainable_parameters} || all params: {all_param} || trainable%: {100 * trainable_parameters / all_param}")

浙公网安备 33010602011771号