pocket_flow通道剪枝--源代码分析

pocketflow 代码地址:https://github.com/Tencent/PocketFlow

通道剪枝运行命令:

./scripts/run_seven.sh nets/resnet_at_cifar10_run.py \
    --learner channel \
    --batch_size_eval 64 \
    --cp_uniform_preserve_ratio 0.5 \
    --cp_prune_option uniform \
    --resnet_size 20

  本文先介绍 Uniform Channel Pruning 的实现:在learners\channel_pruning 中的learner.py实现了模型的通道剪纸,主要代码在train中:

  

通道剪枝支持三种:list是支持对每一层通道都设置特定的裁剪系数,auto:是通过rl 增强学习(怎么实现的)计算出 每一层的裁剪系数,然后对权重裁剪。uniform是对所有卷积层都采用同一个裁剪系数。下面主要介绍uniform这种裁剪类型

其中extract_features中计算卷积层的输出,卷积层的输出数据时(*,H,W,C),从H和W每个维度随机抽样cp_nb_points_per_layer 份数据。compress对不是第一层和最后一层的卷积层进行压缩裁剪;,compress中裁剪的实现在 prune_kernel中,

其中extract_input 获取到该operation的输入,输入的结果会在h和w通道进行sample采样获取到NewX。同时根据feats_dict获取到卷积层的输出Y,其中会对resnet的block中的最后一个卷积层,添加额外的残差输入。

如果设置不使用lasso线性回归,则首先对W2权重按照绝对值累加,以channel通道为轴进行累加,然后根据累加值对N channel通道进行排序并对小的通道值裁剪,最后基于裁剪后权重卷积,用LinearRegression 线性回归生成新的权重 newW2.

如果设置使用 LassoLars 线性回归,则通过 LassoLars 线性回归确定需要裁剪的通道,然后用LinearRegression 线性回归生成新的权重 newW2.

 

posted on 2020-09-04 08:41  xgcode  阅读(350)  评论(0编辑  收藏  举报