神经网络稀疏训练
对于每一个通道都引入一个缩放因子,然后和通道的输出相乘。接着联合训练网络权重和这些缩放因子,最后将小缩放因子的通道直接移除,微调剪枝后的网络,特别地,目标函数被定义为:
其中(x,y)代表训练数据和标签,是网络的可训练参数,第一项是CNN的训练损失函数。是在缩放因子上的乘法项,是两项的平衡因子。论文的实验过程中选择,即正则化,这也被广泛的应用于稀疏化。次梯度下降法作为不平滑(不可导)的L1惩罚项的优化方法,另一个建议是使用平滑的L1正则项取代L1惩罚项,尽量避免在不平滑的点使用次梯度。
这里的缩放因子就是BN层的gamma参数。
在train.py的实现中支持了稀疏训练,其中下面这2行代码即添加了稀疏训练的稀疏系数,注意是作用在BN层的缩放系数上的:
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.01, help='scale sparse rate')
class BNOptimizer():
@staticmethod
def updateBN(sr_flag, module_list, s, prune_idx):
if sr_flag:
for idx in prune_idx:
# Squential(Conv, BN, Lrelu)
bn_module = module_list[idx][1]
bn_module.weight.grad.data.add_(s * torch.sign(bn_module.weight.data)) # L1