pytorch中的numel函数
1. numel函数用于获取tensor中一共包含多少个元素
import torch
x = torch.randn(3,3)
print("number elements of x is ",x.numel())
y = torch.randn(3,10,5)
print("number elements of y is ",y.numel())
输出:
number elements of x is 9
number elements of y is 150
27和150分别位x和y中各有多少个元素或变量
2. 统计模型参数量
num_params = sum(param.numel() for param in net.parameters())
print(num_params)