神经网络稀疏训练

对于每一个通道都引入一个缩放因子,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接移除,微调剪枝后的网络,特别地,目标函数被定义为:

在这里插入图片描述
其中(x,y)代表训练数据和标签,是网络的可训练参数,第一项是CNN的训练损失函数。是在缩放因子上的乘法项,是两项的平衡因子。论文的实验过程中选择,即正则化,这也被广泛的应用于稀疏化。次梯度下降法作为不平滑(不可导)的L1惩罚项的优化方法,另一个建议是使用平滑的L1正则项取代L1惩罚项,尽量避免在不平滑的点使用次梯度。

这里的缩放因子就是BN层的gamma参数。

在train.py的实现中支持了稀疏训练,其中下面这2行代码即添加了稀疏训练的稀疏系数,注意是作用在BN层的缩放系数上的:

parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                        help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.01, help='scale sparse rate') 

class BNOptimizer():

    @staticmethod
    def updateBN(sr_flag, module_list, s, prune_idx):
        if sr_flag:
            for idx in prune_idx:
                # Squential(Conv, BN, Lrelu)
                bn_module = module_list[idx][1]
                bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data))  # L1

link

posted @ 2022-08-19 22:46  luoganttcc  阅读(101)  评论(0)    收藏  举报