神经网络性能评估
1 torchstat
该工具包可通过pip直接安装:
pip install torchstat
使用方法
import torchvision.models as models
#pretrained=True就可以使用预训练的模型
#resnet18 = models.resnet18(pretrained=True)
resnet18 = models.resnet18()
from torchstat import stat
# 第一个参数为待分析的模型,另一个参数表示输入图片的大小
stat(resnet18, (3, 224, 224))
分析的效果如下:
从分析结果可以看出,torchstat的功能非常强大,不仅可以实现FLOPs、参数量、MAdd、显卡内存占用量等模型参数的分析,还可以看到模型每一层的分析结果,工具包不支持的layer也会列在分析结果前提醒使用者。
虽然torchstat的功能十分强大,但是也有一些缺陷:
-
限制模型输入仅能为图片
-
限制模型每一个layer的输入须为单个变量
-
对Pytorch-0.4.1及以下版本的支持不足
以上这些缺陷是在实践中发现的,具体表现为程序报错。如果修改模型也无法适配torchstat,这时就要考虑另选分析工具。
2 thop
对于torchstat无法适用的模型某一个layer的输入为多个变量和Pytorch-0.4.1版本等情况,可以尝试使用thop工具包进行模型分析。
安装
pip install thop
thop工具包相对torchstat而言,功能较为简单,仅支持FLOPs和参数量的计算(或者是我没有发现,不过我看源码是只返回这俩参量)。thop工具包的使用方法如下
from thop import profile
from thop import clever_format
import torchvision.models as models
resnet18 = models.resnet18()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(resnet18, inputs=(input, ))
print(flops,params)
flops,params = clever_format([flops, params],"%.3f")
print(flops,params)
结果如下:
推荐首选torchstat进行模型分析,如果出现无法解决的程序报错,再尝试使用thop
3 ptflops
使用方法如下:
from ptflops import get_model_complexity_info
import torchvision.models as models
resnet18 = models.resnet18()
flops, params = get_model_complexity_info(resnet18, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print('flops: ', flops, 'params: ', params)
结果如下:
4 pytorch_model_summary
安装
pip install pytorch_model_summary
使用
from pytorch_model_summary import summary
import torchvision.models as models
resnet18 = models.resnet18()
nc, nh, nw = 3, 513, 513
batch_size = 1 # 批处理大小
input_shape = (nc, nh, nw) # 输入数据
# set the model to inference mode
resnet18.eval()
inputdata = torch.randn(1, *input_shape) # 生成张量
print(summary(resnet18, inputdata, show_input=False, show_hierarchical=False))

浙公网安备 33010602011771号