Pytorch怎么进行网络剪枝以降低显存占用?

这个问题其实在很多论文里都有提到。很多论文会说我这个网络在测试时只需要其中的某某部分,其他的可以通过剪枝剪掉,但在实际科研或者论文复现中(比如我上一篇的DSRL论文介绍),这个剪枝往往很少有人关注。

 

实际上这个如果真的要放到实际应用中还是很有用的,下面将以Pytorch为例介绍下到底怎么才叫剪枝。原文请参考:http://www.shijinglei.com/2020/05/26/pytorch-%E8%BD%BD%E5%85%A5%E9%A2%84%E8%AE%AD%E7%BB%83%E7%BD%91%E7%BB%9C%E9%83%A8%E5%88%86%E6%9D%83%E9%87%8D/

 

在开始之前,我们需要明确一点,也即这里的剪枝和科研界专门研究权重重要性的剪枝并不是一回事,在这里,我们手头上有一个包含完整网络的参数权重,而我们要做的是将这个权重成功load到测试时使用的轻量网络上。因此,我们实际上会有一个full_net及其对应的权重,以及一个在full_net上剪枝后的partial_net。明确了这一点,就可以开始下面的步骤了:

 

1. 首先weight_dict = torch.load(‘path_to_weight’),读取预训练网络的权重键值。

2. 然后获取当前网络的权重键值,因为Pytorch里的键值和你在代码写的变量名是不一样的。
model_dict = model.state_dict() #model为当前定义的网络

3. 最关键一步,根据键命名筛选出需要载入的部分权重。当前网络中要载入权重的部分,命名要与预训练网络相同
weight_dict = {k:v for k, v in weight_dict.items() if k in model_dict}

4. 更新当前网络的键值字典
model_dict.update(weight_dict)

5. 最后载入该键值字典到网络中
model.load_state_dict(model_dict)

 

这样就完成了整个过程,可以看到剪枝这个动作并不是在测试时动态执行的,而是意味着我们需要另外写一个新的网络结构出来。

 

posted @ 2020-12-29 18:48  Dotman  阅读(283)  评论(0)    收藏  举报