模型压缩

1 模型压缩的概述

深度学习因其计算复杂度或参数见余,在一些场景和设备上限制了相应的模型部署,需要借助模型压缩、优化加速、异构计算等方法突破瓶颈。模型压缩算法能够有效降低参数余,从而减少存储占用、通信带宽和计算复杂度,有助于深度学习的应用部署,具体可划分为如下几种方法:

  • 剪枝(Pruning):剪枝技术通过移除网络中不重要的参数或神经元来减少模型的大小和计算复杂度。
  • 知识蒸馏(Knowledge Distillation):知识蒸馏技术通过将复杂模型(教师模型)的知识传递给较小的模型(学生模型)来进行模型压缩。
  • 量化(Quantization):量化技术通过减少模型参数和计算的数值精度来压缩模型。
  • 网络结构简化:网络结构简化是指通过设计更高效的网络结构来减少计算复杂度和参数数量。

2 剪枝

2.1 什么是剪枝

神经网络在训练过程中为了捕捉数据中的复杂模式,通常会过参数化,这导致模型拥有大量对最终输出贡献不大的冗余参数。而模型剪枝就是指从深度学习神经网络模型中删除冗余参数(即将其置为 0)的技术,它能够在不牺牲太多准确性的前提下压缩模型大小并提高模型的推理速度。

一般来说,只有权重参数(weight)被修剪,而偏置参数(bias)保持不变,因为对偏置的修剪往往会带来更大的负面影响。

image

主要有两种方式来实现模型剪枝,即训练时剪枝和训练后剪枝:

  • 训练时剪枝(train-time pruning)
    训练时剪枝是在模型训练过程中动态进行剪枝操作的方法。这种方法的核心思想是在训练过程中逐步剪枝,从而使模型在训练结束时已经是一个剪枝后的模型。其优点是剪枝过程与模型训练过程紧密结合,能够在训练结束时直接得到剪枝后的模型,避免了额外的后续微调步骤,节省了训练时间和计算资源。缺点是在训练过程中剪枝可能导致模型训练的不稳定,需要精细调整超参数和剪枝策略,增加了实现的复杂性。

  • 训练后剪枝(post-training pruning)
    训练后剪枝是在模型完成训练后对其进行剪枝的方法。先训练一个完整的模型,然后在训练结束后对模型进行剪枝,再进行微调以恢复性能。其优点是剪枝步骤相对独立,无需在训练过程中动态调整剪枝操作,可以根据最终模型的需求选择最优的剪枝策略,再进行微调,以达到更好的性能恢复。缺点是剪枝后可能需要多次微调,才能恢复模型性能。

在上述两种实现方式中,训练后剪枝更为常用,因为它在实施上简单直接,不需要在训练过程中动态调整剪枝操作,剪枝步骤相对独立且容易实施。

2.1 剪枝的分类

  1. 根据剪枝的粒度分,剪枝方法可以大致分为以下六种:
    image

    上图展示了在输入通道(即卷积核通道数)是 3,输出通道(即卷积核个数)是 2 的卷积层上进行六种不同粒度剪枝操作的效果。

    • Fine-grained Pruning(细粒度剪枝)
      细粒度剪枝是指在最细粒度上进行剪枝,即删除某些单个的权重值。这种方法能够在较少的准确率损失下实现更高的剪枝率,但缺点是剪枝后的稀疏矩阵是不规则的(即非零元素的分布是不规则的),这不利于压缩存储和加速计算,需要专门的软硬件支持。
    • Vector-level Pruning(向量级剪枝)
      向量级剪枝是指在向量这一层次上进行参数移除,例如一个神经元的输出权重向量或卷积层中一个滤波器的权重向量。这比细粒度剪枝的范围更大,能够减少剪枝后矩阵的不规则性,便于硬件加速,但在牺牲少量模型表达能力的同时,可能不如细粒度剪枝那样精细地优化模型大小和性能。
    • Kernel-level Pruning(核级剪枝)
      核级剪枝是针对卷积神经网络中的卷积核进行的,即整体移除一个卷积核的所有权重。减少了计算量和模型尺寸,但可能会影响模型捕获细节的能力,因为整个特征检测器被移除。
    • Filter-level Pruning(滤波器级剪枝)
      实质上与核级剪枝相似,滤波器级剪枝专注于从卷积层中移除完整的滤波器。这一策略简化了网络结构,显著降低了计算成本和内存使用,不过可能需要更多的调参来维持原有性能。
    • Channel-level Pruning(通道级剪枝)
      通道级剪枝针对具有通道结构的层,比如卷积层中的通道,移除整个通道的特征映射。这不仅减少了参数量,还可能因通道间潜在的冗余而减少信息重复,但要求对通道重要性有精准的评估,以免损害模型的表达能力。
    • Layer-level Pruning(层级剪枝)
      层级剪枝是最粗粒度的剪枝策略,直接移除整个网络层。这种方法极大地简化了模型结构,加快了推理速度,降低了资源消耗,但需谨慎执行,以防过度简化导致性能急剧下降,需深入分析每一层对最终任务的重要性。
  2. 按照剪枝后的稀疏矩阵是否规则,剪枝算法也可以分为:

    • 非结构化化剪枝(Unstructured Pruning)
      即细粒度剪枝。非结构化稀疏具有更高的模型压缩率和准确性,在通用硬件上的加速效果不好。因为其计算特征上的“不规则”,导致需要特定硬件支持才能实现加速效果。
    • 结构化剪枝(Structured Pruning)
      即向量级、核级、滤波器级、通道级以及层级剪枝。与非结构化剪枝相比,结构化剪枝通常通常会牺牲模型的准确率和压缩比。结构化稀疏对非零权值的位置进行了限制,在剪枝过程中会将一些数值较大的权值移除,从而影响模型准确率。结构化稀疏虽然牺牲了模型压缩率或准确率,但在通用硬件上的加速效果好,所以其被广泛应用。因为结构化稀疏使得权值矩阵更规则更加结构化,更利于硬件加速。

    image

  3. 另外,按照不同的修剪范围,剪枝算法还可以分为:

    • 局部剪枝(Local Pruning)
      局部剪枝是指在模型的特定部分或局部区域进行参数的删除,它允许更精细地控制模型的哪个部分将被简化。这种方法基于这样的假设:模型的不同部分对最终性能的贡献不同,通过有针对性地剪枝某些局部区域,可以更有效地减少计算量和内存占用,同时尽量保持模型的关键性能。
    • 全局剪枝(Global Pruning)
      全局剪枝则是指在整个模型范围内统一进行参数的剪除,不特别区分模型的不同部分。这种方法会基于某种全局性标准(例如基于权值的 L1 范数)来决定哪些参数被剪枝。全局剪枝简化了剪枝过程的复杂度,因为它应用的是统一的剪枝比例或标准,但可能不够灵活,无法充分考虑到模型内部不同部分的差异性对性能的影响。全局剪枝的一个直接结果是模型的整体稀疏度一致,这有时可能导致某些关键区域的性能下降。

2.3 PyTorch 剪枝

2.3.1 PyTorch 剪枝的原理

PyTroch 中神经网络模型或模块类都是继承自 torch.nn.Module 类的,该类中和剪枝实现相关的关键属性包括 _modules_parameters 以及 _buffers。以下是对这些属性的详细介绍:

  • _parameters_parameters 是一个 OrderedDict 对象,用于存储当前模块的所有可学习参数(如 weight 和 bias)及其参数名称,这些参数会在训练过程中被优化器更新。
  • _buffers_buffers 是一个 OrderedDict 对象,用于存储当前模块的所有非学习参数(如 BN 层的 running mean 和 running variance,或者存储剪枝掩码)。这些参数不会在训练过程中被优化器更新,而是用于存储模型状态的一部分。
  • _forward_pre_hooks_forward_pre_hooks 是一个 OrderedDict 对象,用于存储在前向传播前需要调用的钩子(hook)函数及其序号,这些钩子函数会在模型的前向传播前执行一些自定义操作。

并且上述这些属性中的元素都可以通过 del 方式删除以及使用模块对象提供的方法进行查询和添加。

另外,若模块(如 nn.Linear、nn.Conv2d)拥有训练参数,则这些训练参数不仅会存储在 _parameters 中,也会存储在模块对象的 weight 以及 bias 属性中,这些属性是 torch.Tensor 对象(实际上是 torch.nn.parameter.Parameter 类型,它 torch.Tensor 类型的子类):

  • weight:weight 是一个张量,表示模块的权重参数。它用于在前向传播时与输入进行矩阵进行线性运算。
  • bias:bias也是一个张量,表示模块的偏置参数。它在前向传播过程中应用于加到线性变换结果上的偏置值。

掌握了以上预备知识后,继续介绍 PyTorch 剪枝的实现原理,首先定义一个简单的神经网络模型便于演示,以 LeNet 网络为例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune


class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

为了方便进行研究,我们只对上述网络中 fc3 线性层模块的权重参数(weight)进行研究,首先获取代表 fc3
的模块对象:

model = LeNet()
module = model.fc3

接下来对剪枝不同阶段时,个属性参数的变化进行介绍:

  1. 剪枝前

    • module 的 weight 属性和 _parameters 属性中的 weight 参数其实是同一个张量对象。也就是说剪枝前的反向传播会同时更新 weight 属性和 _parameters 属性中的 weight 参数。例如我们输出它们的对象标识:

      print(id(module._parameters['weight']))
      print(id(module.weight))
      
      点击查看输出结果
      1769318681312
      1769318681312
      
    • 由于还未进行剪枝,所以 model 的 _buffers 缓冲区暂时是空的(这是由于还未在其中生成用于剪枝的掩码张量)。例如我们输出 _buffers 的长度:

      print(len(module._buffers))
      
      点击查看输出结果
      0
      
  2. 剪枝时
    现在对 module 的权重参数进行随机非结构化剪枝,代码如下:

    prune.random_unstructured(module, name="weight", amount=0.5)
    

    剪枝时,PyTorch 会具体执行以下步骤:

    • 在 module 的 _parameters 属性中新增一个名为 weight_orig 的参数,该参数存储了剪枝前的原始权重张量。并且会删除 _parameters 属性中原来的 weight 参数(即将原始权重张量的新引用 weight_orig 添加进了 _parameters,并将旧引用 weight 从 _parameters 中删除了)。例如输出剪枝前后 _parameters 中的参数:

      print("Before pruning:")
      [print(k, v.shape, id(v)) for k, v in module._parameters.items()]
      prune.random_unstructured(module, name="weight", amount=0.5)
      print("After pruning:")
      [print(k, v.shape, id(v)) for k, v in module._parameters.items()]
      
      点击查看输出结果
      Before pruning:
      weight torch.Size([10, 84]) 1769318681312
      bias torch.Size([10]) 1769318681392
      After pruning:
      bias torch.Size([10]) 1769318681392
      weight_orig torch.Size([10, 84]) 1769318681312
      
    • 会在 module 的 _buffers 属性中创建一个与权重参数形状相同的掩码张量 weight_mask,剪枝的位置在掩码中被设置为零。例如输出剪枝后的 _buffers 属性:

      print(module._buffers)
      
      点击查看输出结果
      OrderedDict([('weight_mask', tensor([[0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 0., 1., 0.,
      		 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
      		 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1.,
      		 1., 1., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0.,
      		 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 1., 0.],
      		[1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0.,
      		 0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
      		 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1.,
      		 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0.,
      		 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0.],
      		[0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0.,
      		 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0.,
      		 1., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 1.,
      		 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1.,
      		 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 1.],
      		[1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
      		 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.,
      		 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1.,
      		 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0.,
      		 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1.],
      		[0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0.,
      		 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0.,
      		 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0.,
      		 0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1.,
      		 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0.],
      		[1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0.,
      		 0., 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0.,
      		 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1.,
      		 1., 0., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1.,
      		 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0.],
      		[0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 1.,
      		 0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.,
      		 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0.,
      		 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1.,
      		 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.],
      		[0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1.,
      		 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
      		 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0.,
      		 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0.,
      		 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.],
      		[0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 1., 0.,
      		 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 0.,
      		 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1.,
      		 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0.,
      		 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0.],
      		[0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 1.,
      		 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1.,
      		 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0.,
      		 1., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
      		 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.]]))])
      
    • 对 module 的 weight 属性进行重参数化,即将掩码应用到参与前向传播的权重参数上(module.weight = weight_mask * weight_orig)。例如输出 weight 属性:

      print(module.weight)
      
      点击查看输出结果
      tensor([[ 0.0000, -0.0000, -0.0918, -0.0204,  0.0444, -0.0631,  0.0495,  0.0694,
      		  0.0566, -0.0000,  0.0642, -0.0855, -0.0000, -0.0551, -0.0000, -0.0000,
      		  0.0159,  0.0000,  0.0345, -0.0151, -0.0336, -0.0000, -0.0944, -0.0476,
      		  0.0000, -0.0079,  0.0647,  0.0000,  0.0000,  0.0558, -0.0000, -0.0190,
      		  0.0000,  0.0000, -0.0000, -0.0531,  0.0000, -0.0569,  0.0000,  0.0000,
      		  0.0000,  0.0567,  0.0818, -0.0130,  0.0000, -0.0915, -0.0170, -0.0000,
      		 -0.0000,  0.0000,  0.0291, -0.0265, -0.0000,  0.0054,  0.0695,  0.1014,
      		  0.0000, -0.0581, -0.0000, -0.1043, -0.0000, -0.0000, -0.0000,  0.0153,
      		 -0.0820, -0.0215,  0.0000, -0.0000,  0.0000, -0.0000,  0.0226, -0.0000,
      		 -0.0038, -0.0000,  0.0826,  0.0000, -0.0000,  0.0000,  0.0801,  0.0692,
      		  0.1081, -0.0000,  0.0274, -0.0000],
      		[-0.0118,  0.0934, -0.0783,  0.0000, -0.0461, -0.0000, -0.0000, -0.0000,
      		 -0.0000, -0.0000,  0.0507, -0.1090, -0.0925, -0.0000, -0.0328,  0.0000,
      		  0.0000, -0.0000,  0.0000,  0.0000, -0.0331, -0.0000,  0.0000,  0.0746,
      		 -0.0008, -0.0000,  0.1027,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,
      		 -0.0742, -0.0000, -0.0000,  0.0000,  0.0376, -0.0128,  0.0000,  0.0000,
      		  0.0241,  0.0357,  0.0000,  0.0615,  0.0000, -0.0000,  0.0000, -0.0507,
      		 -0.0000,  0.0603,  0.0000, -0.0000, -0.0904, -0.1081,  0.0695, -0.0842,
      		  0.1059, -0.0564,  0.0163,  0.0000,  0.0000, -0.0526,  0.0000,  0.0486,
      		 -0.0736,  0.0532, -0.0954, -0.0000,  0.0169, -0.0000, -0.0000, -0.0000,
      		 -0.0000, -0.0122,  0.1064,  0.0000,  0.0000,  0.0000, -0.0654,  0.0621,
      		 -0.0573, -0.0000, -0.0000,  0.0000],
      		[ 0.0000,  0.0493, -0.0809, -0.0512,  0.0000, -0.0360, -0.0000, -0.0853,
      		  0.1046,  0.0000,  0.0000, -0.0005, -0.0877, -0.0663,  0.0000,  0.0000,
      		 -0.0170, -0.0000, -0.0000,  0.0000,  0.0000,  0.0151, -0.0000,  0.0492,
      		 -0.0962,  0.0000,  0.0124, -0.0560,  0.0000, -0.0000,  0.0828, -0.0263,
      		 -0.0000, -0.0601, -0.0000,  0.0000, -0.0881, -0.0000, -0.0000, -0.0000,
      		  0.1087,  0.0000,  0.0315,  0.0481,  0.0000,  0.0924, -0.0000, -0.0197,
      		 -0.0410,  0.0000, -0.0000, -0.0438,  0.0222, -0.0307,  0.0000,  0.0722,
      		 -0.0826, -0.0000, -0.0000,  0.0500,  0.0000, -0.0000, -0.0320, -0.0962,
      		 -0.0980, -0.0178, -0.0000,  0.0381, -0.0935,  0.0000,  0.0000, -0.0200,
      		 -0.0000, -0.0000,  0.0525,  0.0000, -0.0000, -0.0694,  0.0000, -0.0194,
      		 -0.0020,  0.0000,  0.0022,  0.0729],
      		[-0.0244,  0.0955,  0.0000,  0.0299,  0.0676,  0.0000,  0.0000, -0.0000,
      		  0.0555,  0.0355, -0.0686, -0.0837,  0.0633, -0.0579,  0.0084, -0.0000,
      		 -0.0000,  0.0000,  0.0101,  0.0000, -0.0000, -0.0122, -0.0000,  0.0000,
      		 -0.0924,  0.0000,  0.0000,  0.0000, -0.0601,  0.0397, -0.0000,  0.0000,
      		  0.0379, -0.0708, -0.0000,  0.0000, -0.0000, -0.0000,  0.0626, -0.1044,
      		 -0.0000,  0.0000,  0.0917, -0.0000,  0.0000,  0.0353,  0.0000, -0.0000,
      		 -0.0820, -0.0000, -0.0540, -0.0533,  0.0868,  0.0234,  0.0000, -0.0339,
      		 -0.0000, -0.0000,  0.0427,  0.0000, -0.0077,  0.0464, -0.0000,  0.0000,
      		  0.0000,  0.0000,  0.0744,  0.0000, -0.1040,  0.0000, -0.0751, -0.0000,
      		  0.0102, -0.0715, -0.0684, -0.0000, -0.0000, -0.0000, -0.0606, -0.0323,
      		  0.0000,  0.0000, -0.0537,  0.1001],
      		[ 0.0000,  0.0000, -0.0856, -0.0000,  0.1075,  0.0418,  0.0171, -0.0000,
      		  0.0116,  0.0549, -0.0000,  0.0815, -0.0275, -0.0000, -0.0005, -0.0260,
      		  0.0000, -0.0000, -0.0000, -0.0365, -0.0000,  0.0896, -0.0000,  0.0000,
      		 -0.0348, -0.1037,  0.0373, -0.0230, -0.0309,  0.0000,  0.0000,  0.0688,
      		  0.0405, -0.0694,  0.0000, -0.0000, -0.0000, -0.0000,  0.0446, -0.0000,
      		  0.0000,  0.0510, -0.0000, -0.1036,  0.0250,  0.0464,  0.0000, -0.0671,
      		  0.0705,  0.0327, -0.0000,  0.0000, -0.0462, -0.0000,  0.0000,  0.0009,
      		 -0.0009, -0.0000,  0.0361, -0.0681, -0.0000, -0.0000,  0.0000,  0.0000,
      		 -0.0000, -0.0000,  0.0000, -0.0786, -0.0472,  0.0371, -0.0000, -0.0736,
      		 -0.0000,  0.0953,  0.0262,  0.0000,  0.0900,  0.0830,  0.0269,  0.0000,
      		 -0.0355, -0.0940,  0.0360, -0.0000],
      		[-0.0318, -0.0000, -0.0000, -0.0428, -0.0000,  0.0759,  0.0000,  0.0000,
      		 -0.0714,  0.0000, -0.0356,  0.0884, -0.0763,  0.0000,  0.0644,  0.0000,
      		  0.0000,  0.0000,  0.0000, -0.0000, -0.0844, -0.0262,  0.0000,  0.0000,
      		 -0.0104,  0.0000,  0.0926,  0.0505, -0.0399,  0.0000, -0.0000, -0.0132,
      		 -0.0622,  0.0954,  0.0000,  0.0000,  0.0000,  0.0000, -0.0031, -0.0000,
      		 -0.0000, -0.0000, -0.0000, -0.0517,  0.1019,  0.0196, -0.0557, -0.0567,
      		  0.0000,  0.0579, -0.0000,  0.0000,  0.0138, -0.0131,  0.0396, -0.0000,
      		  0.0000,  0.0752,  0.0716, -0.0000, -0.0718,  0.0000,  0.0552,  0.0000,
      		  0.0000,  0.0061, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0254,
      		  0.0562, -0.1015,  0.0000,  0.0000,  0.0073,  0.0000,  0.0176, -0.0265,
      		 -0.0376,  0.0343,  0.0167, -0.0000],
      		[-0.0000,  0.0080,  0.0806,  0.0645, -0.0000, -0.0000, -0.0024,  0.0252,
      		 -0.0404, -0.0000,  0.0770, -0.0000, -0.0000,  0.0000, -0.0110, -0.0252,
      		  0.0000, -0.0940,  0.0000,  0.0448,  0.0000, -0.0000,  0.0000, -0.0253,
      		  0.0765,  0.0000, -0.0022,  0.0411,  0.0000,  0.0000,  0.0958, -0.0000,
      		  0.0000,  0.0601,  0.0895, -0.0158, -0.0000, -0.0000,  0.0347,  0.0000,
      		  0.0000,  0.0903, -0.0000,  0.0000,  0.0585, -0.0000, -0.0539, -0.0000,
      		 -0.0000,  0.0627, -0.0006, -0.0580, -0.0069, -0.0000,  0.0992, -0.0628,
      		 -0.0237, -0.0812,  0.0508,  0.0000,  0.0000, -0.0000,  0.0611, -0.0760,
      		 -0.0000,  0.0026,  0.0326,  0.0917, -0.0000, -0.0000,  0.0865,  0.0780,
      		  0.0816, -0.0634,  0.0505,  0.0261, -0.0000, -0.0000, -0.0000,  0.0328,
      		  0.0480, -0.0000, -0.0000,  0.0000],
      		[ 0.0000,  0.0536, -0.0000,  0.0000, -0.0200, -0.0000, -0.0000,  0.0095,
      		 -0.0000,  0.0099,  0.0079,  0.0000, -0.0000, -0.0000, -0.0133, -0.0419,
      		 -0.0469, -0.0622, -0.0987, -0.0369, -0.0000,  0.0890,  0.0183,  0.0247,
      		 -0.0379, -0.0345, -0.0276,  0.0000,  0.0000, -0.0000,  0.0496, -0.0527,
      		 -0.0000,  0.0049, -0.0000, -0.1017, -0.0000,  0.0000, -0.0000,  0.0505,
      		  0.0000, -0.0000, -0.0352, -0.0912,  0.0067,  0.0960, -0.0000,  0.0000,
      		 -0.0026, -0.0685, -0.0000,  0.0288, -0.0646, -0.0000, -0.0000,  0.0000,
      		  0.0000, -0.0450, -0.0365, -0.0784,  0.0201, -0.0000,  0.0000,  0.0349,
      		 -0.0727, -0.0329,  0.0522,  0.0642, -0.0000, -0.0137,  0.0000, -0.0000,
      		 -0.0000,  0.0073,  0.0000,  0.0000, -0.0000, -0.0000,  0.0000,  0.0694,
      		 -0.0000,  0.0000, -0.0013,  0.0000],
      		[-0.0000,  0.0000, -0.0000, -0.0302,  0.0000,  0.0949, -0.0580,  0.0000,
      		 -0.0588,  0.0000, -0.0206,  0.0172, -0.0000,  0.0000, -0.0120, -0.0135,
      		 -0.0019,  0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0168, -0.0780,
      		 -0.0000,  0.0000,  0.0331,  0.0000, -0.0000, -0.0293, -0.0534,  0.0000,
      		  0.0502,  0.0000,  0.0000,  0.0000, -0.0998,  0.0000,  0.0742, -0.0147,
      		 -0.1032,  0.0853,  0.0503,  0.0846,  0.0750,  0.0162, -0.0000, -0.0994,
      		  0.0957,  0.1058,  0.0000,  0.0000,  0.0000, -0.0841, -0.0000,  0.0159,
      		 -0.0692,  0.0000,  0.0000, -0.0000, -0.0000,  0.1037, -0.0206, -0.0000,
      		 -0.0893,  0.0000, -0.0000,  0.0133, -0.0786, -0.0000, -0.0000, -0.0000,
      		  0.1054, -0.0000, -0.0000, -0.0729,  0.0000, -0.0152,  0.0000,  0.0149,
      		  0.0430, -0.0602,  0.0355, -0.0000],
      		[-0.0000,  0.0000,  0.0752, -0.0000,  0.0000,  0.0000, -0.0273, -0.0308,
      		 -0.0000, -0.0962,  0.0000,  0.0268,  0.0000,  0.0777,  0.0210, -0.0000,
      		 -0.0000,  0.0507,  0.0736, -0.0917, -0.0100,  0.0000,  0.0000,  0.0922,
      		  0.0657,  0.0000, -0.0000,  0.0000,  0.0682, -0.0926, -0.0000, -0.0000,
      		  0.0000, -0.0000, -0.0571,  0.0771, -0.0000,  0.0821,  0.0468,  0.0000,
      		 -0.0000, -0.0000,  0.0317, -0.0000, -0.0000,  0.0643, -0.0000, -0.0041,
      		  0.0000, -0.0000,  0.0845,  0.0621,  0.0000,  0.0000,  0.0761, -0.0151,
      		  0.0584, -0.0000, -0.0000,  0.0290, -0.0000, -0.0000,  0.0699,  0.0089,
      		  0.0859, -0.0000,  0.0000, -0.0684,  0.0000,  0.0000, -0.0000,  0.0000,
      		  0.0279,  0.0149, -0.0000, -0.0000, -0.0132, -0.0531,  0.0495, -0.0000,
      		  0.0000,  0.0000, -0.0000, -0.0000]], grad_fn=<MulBackward0>)
      
    • 会在 module 的 _forward_pre_hooks 属性中注册一个剪枝方法(其实是一个剪枝类对象,在此案例中是 RandomUnstructured 类对象),该剪枝方法会在每次前向传播之前被调用。例如输出 module 的 _forward_pre_hooks 属性:

      print(module._forward_pre_hooks)
      
      点击查看输出结果
      OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x00000174F3BEC280>)])
      
  3. 剪枝后进行前向传播时的
    在调用模型进行推理时,_forward_pre_hooks 中注册的剪枝方法会先被调用,使得 module 的 weight 属性再次重参数化,确保后续参与参与前向传播的权重参数是经过剪枝的稀疏权重。

  4. 剪枝后进行反向传播时
    由于前向传播过程中使用的是经过剪枝后的权重进行的计算,所以反向传播计算的梯度只会流向未被剪枝的权重,即只有未被剪枝的权重会得到更新。值得注意的是优化器操作的,优化器更新的是 _parameters 属性中的 weight_orig 参数,而 weight 属性中的参数只会在进行剪枝时或前向传播前进行重参数化更新。

2.3.2 PyTorch 的剪枝函数

PyTorch 剪枝功能在 torch.nn.utils.prune 模块中实现,主要的剪枝类如下图所示:
image

以上的剪枝类中 BasePruningMethod 类是 PyTorch 中剪枝模块的基础抽象类,它定义了剪枝方法所需的基本接口和属性,它不直接用于剪枝操作,而是供其他具体剪枝方法继承。其他具体的剪枝类从它继承而来的,根据需要进行了定制化和扩展,它们具体功能如下:

  • PruningContainer 类:是一个容器类,用于将多个剪枝方法组合在一起应用于模型的不同部分。它可以包含多个剪枝方法对象,统一管理和应用这些剪枝方法。
  • Identity 类:实现了在不修剪任何单元的情况下应用剪枝重参数化的功能,通常作为基准或占位符使用。
  • RandomUnstructured 类:实现了随机非结构化剪枝的功能。
  • L1Unstructured 类:实现了基于 L1 范数的非结构化剪枝的功能。
  • RandomStructured 类:实现了随机结构化剪枝的功能。
  • LnStructured 类:实现基于 Ln 范数的构化剪枝的功能。
  • CustomFromMask 类:实现了根据预定义的掩码对权重张量进行自定义结构化剪枝的功能。

虽然在 torch.nn.utils.prune 模块中定义了这些剪枝类,但我们通常不直接实例化和使用它们,而是通过 torch.nn.utils.prune 模块提供的便捷函数来进行剪枝操作。这些便捷函数对底层的剪枝类进行了封装,使剪枝操作更加简洁和易于使用。下面是对封装函数的详细介绍:

  1. identity 函数

    def identity(module, name)
    

    函数 identity 的作用是在不实际移除任何神经网络单元的情况下,对指定模块中的某个参数实施剪枝操作的重新参数化过程。此操作直接修改输入的 module 并将其返回。
    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。

    返回值:

    • 经过剪枝后的模型模块。

  1. random_unstructured 函数
    函数 random_unstructured 作用是执行随机非结构化剪枝操作。从 module 中名为 name 的参数中随机移除指定数量的单元(当前未剪枝的)。此操作直接修改输入的 module 并将其返回。

    def random_unstructured(module, name, amount)
    

    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。
    • amount(int 或 float):若为 float,应介于 0.0 到 1.0 之间,表示参数被剪枝的比例;若为 int,则直接表示剪枝的参数个数。

    返回值:

    • 经过剪枝后的模型模块。

  1. l1_unstructured 函数
    函数 l1_unstructured 作用是根据 L1 范数执行非结构化剪枝操作。从 module 中名为 name 的参数中移除具有最低 L1 范数的指定数量的单元(当前未剪枝的)。此操作直接修改输入的 module 并将其返回。

    def l1_unstructured(module, name, amount, importance_scores=None)
    

    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。
    • amount(int 或 float):若为 float,应介于 0.0 到 1.0 之间,表示参数被剪枝的比例;若为 int,则直接表示剪枝的参数个数。
    • importance_scores(torch.Tensor, optional):形状与模块参数相同的重要性分数张量,用于确定剪枝掩码。如果不提供或为 None,则使用模块参数本身计算重要性。

    返回值:

    • 经过剪枝后的模型模块。

  1. random_structured 函数
    函数 random_structured 执行随机结构化剪枝操作。从 module 中名为 name 的参数中沿着指定的 dim 维度随机选择并移指定数量的结构(如向量、卷积核,通道等)。此操作直接修改输入的 module 并将其返回。

    def random_structured(module, name, amount, dim)
    

    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。
    • amount(int 或 float):若为 float,应介于 0.0 到 1.0 之间,表示参数被剪枝的比例;若为 int,则直接表示剪枝的参数个数。
    • dim(int):指定要进行剪枝的维度索引。

    返回值:

    • 经过剪枝后的模型模块。

  1. ln_structured 函数
    函数 ln_structured 依据指定维度上的最低 Ln 范数执行结构化剪枝操作。移除选定数量的通道。从 module 中名为 name 的参数中沿着指定的 dim 维度中移除具有最低 Ln 范数的指定数量的结构(如向量、卷积核,通道等)。此操作直接修改输入的 module 并将其返回。

    def ln_structured(module, name, amount, n, dim, importance_scores=None)
    

    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。
    • amount(int 或 float):若为 float,应介于 0.0 到 1.0 之间,表示参数被剪枝的比例;若为 int,则直接表示剪枝的参数个数。
    • n(int 或 float):规定了用于计算范数的类型,对应于 torch.norm 函数中的 p 参数,用于衡量权重的重要性。
    • dim(int):指定要进行剪枝的维度索引。
    • importance_scores(torch.Tensor, optional):形状与模块参数相同的重要性分数张量,用于确定剪枝掩码。如果不提供或为 None,则使用模块参数本身计算重要性。

    返回值:

    • 经过剪枝后的模型模块。

  2. global_unstructured 函数
    函数 global_unstructured 作用是执行全局非结构化剪枝操作,它的核心作用是在不考虑参数空间结构的情况下,基于全局重要性评估,统一决定哪些参数进行剪枝。这允许开发者跨多个层或模块统一剪枝标准,简化剪枝过程,尤其适合于想要快速减小模型大小而不过多考虑特定结构影响的情况。此操作直接修改输入的 module 并将其返回。

    def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)
    

    参数说明:

    • parameters(Iterable of (module, name) tuples):这是一个包含多个 (module, name) 对的可迭代对象,其中 module 是 nn.Module 类型的神经网络层,name 是该层中参数的名称(通常是权重 weight 或偏置 bias)。这些参数将被全局地、非结构化地进行剪枝。
    • pruning_method(function):一个剪枝函数,用于确定进行非结构化剪枝的方式。内置的非结构化剪枝方法有:
      (1)L1Unstructured
      (2)RandomUnstructured
    • importance_scores(dict, optional): 一个字典,映射每个 (module, name) 对到一个同形状的重要性分数张量。这些分数用于决定剪枝哪些元素。如果不提供或为 None,则使用模块参数本身计算重要性。
    • kwargs:其他关键字参数,最常见的是 amount,用于指定剪枝的数量或比例。

    返回值:

    • 经过剪枝后的模型模块。

  1. custom_from_mask 函数
    custom_from_mask 函数允许用户使用预先计算好的掩码(mask)来对模块中的特定参数进行剪枝。这个函数主要用于自定义剪枝策略,当用户已经知道哪些元素应该被剪枝,或者从其他来源获取了剪枝掩码时使用。此操作直接修改输入的 module 并将其返回。

    def custom_from_mask(module, name, mask)
    

    参数说明:

    • module(nn.Module):需要进行剪枝的神经网络模块。
    • name(str):模块中待进行剪枝的参数名称。
    • mask(Tensor):二进制掩码张量,其形状与目标参数相同。

    返回值:

    • 经过剪枝后的模型模块。

  1. remove 函数
    remove 函数的作用是从给定的模块中移除之前应用的剪枝重新参数化,以及从前向传播钩子中移除剪枝方法。这意味着,虽然之前被剪枝的参数状态(按照剪枝掩码永久剪枝)会被保留,但是与剪枝相关的额外组件(原始参数和掩码)会被清理,使得模型更加简洁,以便于后续的部署或进一步处理。

    该操作并不“撤销”剪枝,也就是说,被剪枝的权重并不会被恢复。它只是移除了剪枝的痕迹,使得模型看起来像一个未经剪枝的模型,但是剪枝的效果(即参数的缺失)仍然存在。

    def remove(module, name)
    

    参数说明:

    • module(nn.Module): 需要移除剪枝设置的神经网络模块。
    • name(str):要移除剪枝参数的名称,该参数在之前已经被剪枝过。

    返回值:

    • None

  1. is_pruned 函数
    is_pruned 函数用于检查给定的神经网络模块是否已经应用了剪枝操作。它通过检测模块中是否存在继承自 BasePruningMethod 的前向传播预钩子来判断模块是否已被剪枝。

    def is_pruned(module)
    

    参数说明:

    • module(nn.Module): 需要检查是否经过剪枝的神经网络模块。

    返回值:

    • 果模块已应用剪枝,则返回 True;否则,返回 False。

2.4 剪枝的代码实现

2.4.1 导入必要的依赖

import os
import time

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from matplotlib import pyplot as plt
from torch.nn.utils import prune
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.datasets import CIFAR10

2.4.2 定义训练早停类

class EarlyStopping:
    def __init__(self, patience=100, min_delta=0):
        if patience <= 0:
            raise ValueError('patience must be a positive integer')
        if min_delta < 0:
            raise ValueError('min_delta must be a positive number')

        self.patience = patience
        self.current_patience = patience
        self.min_delta = min_delta
        self.early_stop = False
        self.best_acc = 0
        self.best_loss = None
        self.best_epoch = None

    def __call__(self, val_loss, val_acc, epoch):
        if val_acc > self.best_acc + self.min_delta:
            self.current_patience = self.patience
            print(f'Epoch {epoch}: '
                  f"best acc improved from {self.best_acc:.4f} to {val_acc:.4f}, "
                  f'current patience reset to {self.patience}')
            self.best_acc = val_acc
            self.best_loss = val_loss
            self.best_epoch = epoch
        else:
            self.current_patience -= 1
            if self.current_patience <= 0:
                self.early_stop = True
                print(f"{self.patience} patience exhausted, early stopping! "
                      f"Best acc: {self.best_acc:.4f}, "
                      f"and best loss: {self.best_loss:.4f},"
                      f"appear at epoch {self.best_epoch}")
            else:
                print(f'Epoch {epoch}: '
                      f'best acc did not improve from {self.best_acc:.4f}, '
                      f'current patience: {self.current_patience}/{self.patience}')
        return self.early_stop

2.4.3 定义训练基础类

class BaseTrainer:
    def __init__(self,
                 num_epochs,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0,
                 ):
        # 设置随机种子
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # 设置设备
        if device == "cuda" and torch.cuda.is_available():
            print("Using GPU for training")
            self.device = torch.device("cuda")
        else:
            if not torch.cuda.is_available():
                print("CUDA is not available, using CPU instead")
            else:
                print("Using CPU for training")
            self.device = torch.device("cpu")

        # 定义输出位置
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):  # 如果不存在该目录,则创建该目录
            os.makedirs(self.save_dir)

        # 加载模型
        self.model = self.load_model()
        self.model.to(self.device)

        # 定义训练参数
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.optimizer, self.scheduler = self.load_optimizer_and_scheduler()
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.history = []
        self.train_loader, self.test_loader = self.get_data_loader()

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_model(self):
        num_classes = 10
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

    def load_optimizer_and_scheduler(self):
        lr = 0.01
        momentum = 0.9
        weight_decay = 1e-4
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.num_epochs)
        print(f"Optimizer: {optimizer.__class__.__name__}, lr={lr}, momentum={momentum}, weight_decay={weight_decay}")
        print(f"Scheduler: {scheduler.__class__.__name__}")
        print(f"You can change the optimizer and scheduler in the load_optimizer_and_scheduler method!")
        return optimizer, scheduler

    def save_model(self, filename):
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, f"{filename}.pt"))

    def save_graph(self, epoch):
        # 展示训练过程
        history_np = np.array(self.history)
        plt.figure(figsize=(10, 10))  # 创建一个10*10的画布绘制损失曲线图
        plt.plot(np.arange(1, epoch + 1), (history_np[:, [0, 2]]))  # 画出训练集损失和验证集损失
        plt.legend(['Train Loss', 'Test Loss'])  # 显示图例
        plt.xlabel('Epoch')  # 设置x轴标签
        plt.ylabel('Loss')  # 设置y轴标签
        x_ticks = np.arange(0, epoch + 1, step=10)
        x_ticks[0] = 1  # epoch是从1开始的
        plt.xticks(x_ticks)  # 设置坐标轴刻度
        plt.yticks(np.arange(0, 2.05, 0.1))  # 设置坐标轴刻度
        plt.grid()  # 画出网格
        plt.gca().set_ylim([0, 2])  # 设置y轴范围
        best_test_loss_idx = np.argmin(history_np[:, 2])
        best_test_loss = history_np[best_test_loss_idx, 2]
        plt.text(
            0.1,
            1.05,
            f'Best test loss: {best_test_loss:.4f}, epoch: {best_test_loss_idx + 1}',
            ha='center',
            va='center',
            transform=plt.gca().transAxes
        )
        plt.savefig(os.path.join(self.save_dir, 'loss_curve.png'))  # 保存图片
        plt.close()  # 关闭画布

        plt.figure(figsize=(10, 10))  # 创建一个10*10的画布绘制准确率曲线图
        plt.plot(np.arange(1, epoch + 1), (history_np[:, [1, 3]]))  # 画出训练集准确率和验证集准确率
        plt.legend(['Train Accuracy', 'Test Accuracy'])  # 显示图例
        plt.xlabel('Epoch')  # 设置x轴标签
        plt.ylabel('Accuracy')  # 设置y轴标签
        x_ticks = np.arange(0, epoch + 1, step=10)
        x_ticks[0] = 1  # epoch是从1开始的
        plt.xticks(x_ticks)  # 设置坐标轴刻度
        plt.yticks(np.arange(0, 1.05, 0.05))  # 设置坐标轴刻度
        plt.grid()  # 画出网格
        best_test_acc_idx = np.argmax(history_np[:, 3])
        best_test_acc = history_np[best_test_acc_idx, 3]
        plt.text(
            0.1,
            1.05,
            f'Best test acc: {best_test_acc:.4f}, epoch: {best_test_acc_idx + 1}',
            ha='center',
            va='center',
            transform=plt.gca().transAxes
        )
        plt.savefig(os.path.join(self.save_dir, 'acc_curve.png'))  # 保存图片
        plt.close()  # 关闭画布

    def train(self, epoch):
        train_total_loss = 0
        train_total_acc = 0
        self.model.train()
        for images, labels in tqdm.tqdm(self.train_loader, total=len(self.train_loader),
                                        desc=f'Training Epoch {epoch}/{self.num_epochs}', delay=0.1):
            images = images.to(self.device)
            labels = labels.to(self.device)
            outputs = self.model(images)
            loss = F.cross_entropy(outputs, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            train_total_loss += loss.item()
            train_total_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
        train_avg_loss = train_total_loss / len(self.train_loader)
        train_avg_acc = train_total_acc / len(self.train_loader.dataset)
        print(f'Epoch {epoch}/{self.num_epochs}, Train Loss: {train_avg_loss:.4f}, Train Acc: {train_avg_acc:.4f}')
        return train_avg_loss, train_avg_acc

    def test(self, epoch):
        test_total_loss = 0
        test_total_acc = 0
        self.model.eval()
        with torch.no_grad():
            for images, labels in tqdm.tqdm(self.test_loader, total=len(self.test_loader),
                                            desc=f'Testing Epoch {epoch}/{self.num_epochs}', delay=0.1):
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(images)
                loss = F.cross_entropy(outputs, labels)
                test_total_loss += loss.item()
                test_total_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
        test_avg_loss = test_total_loss / len(self.test_loader)
        test_avg_acc = test_total_acc / len(self.test_loader.dataset)
        print(f'Epoch {epoch}/{self.num_epochs}, Test Loss: {test_avg_loss:.4f}, Test Acc: {test_avg_acc:.4f}')
        return test_avg_loss, test_avg_acc

    def run(self):
        start_time = time.time()  # 记录开始时间
        if self.history:
            raise Exception("This object has already run, please create a new object")
        for epoch in range(1, 1 + self.num_epochs):
            print(f"Current lr: {self.scheduler.get_last_lr()[0]:.6f}")
            train_loss, train_acc = self.train(epoch)
            test_loss, test_acc = self.test(epoch)
            self.history.append((train_loss, train_acc, test_loss, test_acc))
            self.scheduler.step()
            self.save_graph(epoch)
            self.save_model("last")
            if self.early_stopping(test_loss, test_acc, epoch):
                break
            if self.early_stopping.current_patience == self.early_stopping.patience:
                self.save_model("best")
            print()
        print(f"Training finished. "
              f"Best acc: {self.early_stopping.best_acc:.4f} "
              f"and best loss: {self.early_stopping.best_loss:.4f} "
              f"appear at epoch {self.early_stopping.best_epoch}")
        end_time = time.time()  # 记录结束时间
        period = round(end_time - start_time)
        print(f'All time: {period} s')
        print(f"All Time: {period // 60:d} min {period % 60:d} s")

2.4.4 定义剪枝训练类

class PruneTrainer(BaseTrainer):
    def __init__(self,
                 num_epochs,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0
                 ):
        super().__init__(num_epochs, batch_size, save_dir, device, num_workers, patience, min_delta)

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_model(self):
        num_classes = 10
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

2.4.5 定义施加剪枝和移除剪枝的函数

def apply_pruning(model, prune_rate):
    for name, module in model.named_modules():
        # 对layer进行L1结构化剪枝
        if isinstance(module, torch.nn.Conv2d):  # 不对resent18的全连接层进行剪枝,因为是输出层
            if prune_rate > 0:
                prune.l1_unstructured(module, 'weight', amount=prune_rate)
            else:
                prune.identity(module, 'weight')


def remove_pruning(model):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            if prune.is_pruned(module):
                prune.remove(module, 'weight')

2.4.6 开始训练

if __name__ == '__main__':
    total_prune_rate = 0
    per_prune_rate = 0.5
    acc_decay_tol = 0.02
    prune_trainer = PruneTrainer(
        num_epochs=100,
        batch_size=256,
        patience=20,
        min_delta=0,
        num_workers=2,
        save_dir=f"prune/rate_{total_prune_rate}")
    apply_pruning(prune_trainer.model, 0)  # 将模型剪枝化
    prune_trainer.run()
    history_best_acc = prune_trainer.early_stopping.best_acc
    pre_weight = torch.load(f"prune/rate_{total_prune_rate}/best.pt")
    pre_total_prune_rate = 0
    pre_best_acc = prune_trainer.early_stopping.best_acc

    while total_prune_rate < 1:
        total_prune_rate += (1 - total_prune_rate) * per_prune_rate
        print(f"prune rate: {total_prune_rate}")
        prune_trainer = PruneTrainer(
            num_epochs=100,
            batch_size=256,
            patience=20,
            min_delta=0,
            num_workers=2,
            save_dir=f"prune/rate_{total_prune_rate}")
        apply_pruning(prune_trainer.model, 0)
        prune_trainer.model.load_state_dict(pre_weight)
        apply_pruning(prune_trainer.model, total_prune_rate)
        prune_trainer.run()
        if prune_trainer.early_stopping.best_acc < history_best_acc - acc_decay_tol:
            # 将上次的剪枝作为最终结果
            print(f"acc decay more than {acc_decay_tol}, stop pruning, "
                  f"final prune rate: {pre_total_prune_rate}, "
                  f"final best acc: {pre_best_acc}")
            prune_trainer.model.load_state_dict(pre_weight)
            remove_pruning(prune_trainer.model)
            torch.save(prune_trainer.model.state_dict(), f"prune/best.pt")
            break
        else:
            # 继续剪枝
            print(f"history_best_acc: {history_best_acc}, "
                  f"current best_acc: {prune_trainer.early_stopping.best_acc}, "
                  f"acc decay less than {acc_decay_tol}, continue pruning\n")
            if prune_trainer.early_stopping.best_acc > history_best_acc:
                history_best_acc = prune_trainer.early_stopping.best_acc
            pre_weight = prune_trainer.model.state_dict()
            pre_total_prune_rate = total_prune_rate
            pre_best_acc = prune_trainer.early_stopping.best_acc

3 知识蒸馏

3.1 什么是知识蒸馏

知识蒸馏(Knowledge Distillation)是一种在模型压缩和迁移学习领域非常重要的技术。其核心思想是先训练一个复杂网络模型,然后使用这个复杂网络的输出和数据的真实标签去训练一个更小的网络,从而在保持高精度的同时,显著减少模型的参数量和计算量。其中复制的模型被称为 Teacher 模型,小模型被称为 Student 模型。
image

知识蒸馏采用 Teacher-Student 模式,即使用复杂且大型的模型作为教师(Teacher),辅助结构较为简单的学生模型(Student)的训练。教师模型由于学习能力强,可以将学到的知识迁移给学习能力相对弱的学生模型,从而增强学生模型的泛化能力。复杂但效果好的教师模型不会上线使用,而是作为导师角色,真正部署进行预测任务的是灵活轻巧的学生模型。

知识蒸馏是对模型能力的迁移,根据迁移方法的不同,可以分为以下两大方向:

  • 基于目标蒸馏(也称 Soft-target 蒸馏或 Logits 方法蒸馏):通过教师模型的输出软标签来指导学生模型。
  • 基于特征蒸馏:通过教师模型的中间层特征来指导学生模型。

3.2 知识蒸馏的作用

知识蒸馏的作用可以概况为以下三种:

  • 提升模型精度

    如果对目前的网络模型 A 的精度不是很满意,那么可以先训练一个更高精度的 teacher 模型 B(通常参数量更多,推理延迟时间更长),然后用这个训练好的 teacher 模型 B 对 student 模型 A 进行知识蒸馏,得到一个更高精度的 A 模型。

  • 降低模型时延,压缩网络参数

    如果对目前的网络模型 A 的推理延迟或模型大小不满意,可以先找到一个推理延迟更低,参数量更小的模型 B,通常来讲,这种模型精度也会比较低,然后使用网络模型 A 对这个参数量小的模型 B 进行知识蒸馏,使得该模型 B 的精度接近模型 A 的同时,能达到降低推理延迟和减小模型大小的目的。

  • 标签之间的域迁移

    假如使用狗和猫的数据集训练了一个 teacher 模型 A,使用香蕉和苹果训练了一个 teacher 模型 B,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移。

3.3 知识蒸馏与从头训练的区别

使用教师模型 B 对学生模型 A 进行蒸馏训练,与直接用数据集训练模型 A 相比,主要区别在于知识传递和优化路径。

模型 B 已经在搜索空间中找到最优或次优的参数组合,这些知识通过软标签传递给模型 A,缩小了其搜索空间,使模型 A 不仅从数据中学习,还从教师模型的预测中获取丰富信息,能更快地找到有效参数组合,从而在较短时间内达到较高性能。

相比之下,直接训练模型 A 需要从随机初始化开始,在模型整个搜索空间中查找参数最优解,所以优化难度大,训练时间长,且容易陷入局部最优解。
image

3.4 基于目标蒸馏

目标蒸馏方法中最经典的论文就是来自于 2015 年 Hinton 发表的一篇神作《Distilling the Knowledge in a Neural Network》。在这篇论文中,Hinton 将问题限定在分类问题下,分类问题的共同点是模型最后会有一个 Softmax 层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的 Teacher 模型,我们在利用 Teacher 模型来蒸馏训练 Student 模型时,可以直接让 Student 模型去学习 Teacher 模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用 Softmax 层输出的类别的概率来作为 Soft-target 。
image

3.4.1 Hard-target 和 Soft-target

传统的神经网络训练方法是定义一个损失函数,目的是让模型的预测尽可能接近真实值(Hard- target),损失函数的设计旨在最小化神经网络的预测与真实标签之间的差异,从而使得模型能够准确地预测数据的真实性(即对 ground truth 求极大似然)。在基于目标的知识蒸馏中,比较特别的是使用了 Teacher 模型的预测类别概率作为 Soft-target 来指导学生模型的训练。

  • Hard-target:原始数据集标注的 one-hot 标签,除了正标签为 1,其他负标签都是 0。
  • Soft-target:Teacher 模型 Softmax 层输出的类别概率,每个类别都分配了概率,正标签的概率最高。

image

知识蒸馏用 Teacher 模型预测的 Soft-target 来辅助 Hard-target 训练 Student 模型的方式为什么有效呢?Softmax 层的输出,除了正例之外,负标签也带有 Teacher 模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher 模型在推理时认为该样本与该负标签有一定的相似性。而在传统的训练过程(Hard-target)中,所有负标签都被统一对待。也就是说,知识蒸馏的训练方式使得每个样本给 Student 模型带来的信息量大于传统的训练方式。

如在 MNIST 数据集中做手写体数字识别任务,假设某个输入的 "2" 更加形似 "3",Softmax 的输出值中 "3" 对应的概率会比其他负标签类别高;而另一个 "2" 更加形似 "7",则这个样本分配给 "7" 对应的概率会比其他负标签类别高。这两个 "2" 对应的 Hard-target 的值是相同的,但是它们的 Soft-target 却是不同的,由此我们可见 Soft-target 蕴含着比 Hard-target 更多的信息。

image

在使用 Soft-target 训练时,Student 模型可以很快学习到 Teacher 模型的推理过程;而传统的 Hard-target 的训练方式,所有的负标签都会被平等对待。因此,Soft-target 给 Student模型带来的信息量要大于 Hard-target,并且 Soft-target 分布的熵相对高时,其 Soft-target 蕴含的知识就更丰富。同时,使用 Soft-target 训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。这也解释了为什么通过蒸馏的方法训练出的 Student 模型相比使用完全相同的模型结构和训练数据只使用 Hard-target 的训练方法得到的模型,拥有更好的泛化能力。

3.4.2 蒸馏温度

Logits 是深度神经网络中未经 Softmax 处理的输出值,它们代表了模型在每个类别上的得分或概率(未归一化)。在分类问题中,这些 logits 汇总了网络内部对各个类别的信息,每个 logits 表示输入属于对应类别的得分,而不是概率。假设有一个分类任务,最后一个输出层是全连接层,则它的直接输出就是 logits,可以用一个向量 \(z\) 表示:

\[z=\left[z_1, z_2, \ldots, z_n\right] \]

其中:

(1)\(n\) 是总的类别数。

(2)\(z_i\) 是第 \(i\) 个类别的预测得分。

在分类任务重,我们通常不直接使用 logits,而是用 Softmax 函数将 logits 转换为概率分布,使得这些概率之和为 1。logits 经 Softmax 函数处理后各类的概率如下:

\[q_i=\frac{\exp \left(z_i\right)}{\sum_{j=1}^n \exp \left(z_j\right)} \]

其中:

(1)\(n\) 是总的类别数。

(2)\(z_i\) 是第 \(i\) 个类别的预测得分。

(3)\(q_i\) 是第 \(i\) 个类别的预测概率。

但是直接使用 Softmax 层的输出值 \(q\) 作为 Soft-target 会带来一个问题,当 Softmax 输出的概率分布熵较小时,某些类别的概率会接近 0 或 1,这可能导致得到的 Soft-target 近似于 Hard-target,从而失去了 Soft-target 的优势和信息。一种常见的解决方法是引入蒸馏温度参数 𝑇 ,通过缩放 logits 来调节 Softmax 输出的概率分布。温度调节可以使 Softmax 输出的概率分布更加平滑,减少极端概率值的出现,从而提高 Soft-target 的有效性。下面的公式是加了蒸馏温度之后的 Softmax 函数:

\[q^T_i=\frac{\exp \left(z_i / T\right)}{\sum_{j=1}^n \exp \left(z_j / T\right)} \]

其中:

(1)\(n\) 是总的类别数。

(2)\(z_i\) 是第 \(i\) 个类别的预测得分。

(3)\(T\) 是蒸馏温度参数,通常是高温蒸馏,即 \(T>1\)

(4)\(q_i^T\) 是经温度 \(T\) 修正后的第 \(i\) 个类别的预测概率。

温度参数 \(T\) 有这样几个特点:

  • \(T=1\) 时,Softmax 函数为标准的 Softmax 函数,即没有温度调节。
  • \(T<1\) 时,Softmax 函数输出的概率差异变大,即概率分布比标准 Softmax 函数的输出更不均匀。
  • \(T\rightarrow 0\) 时,Softmax 函数输出的概率分布接近 Hard-target,即输出概率接近 0 或 1。
  • \(T>1\) 时,Softmax 函数输出的概率差异变小,即概率分布比标准 Softmax 函数的输出更均匀。
  • \(T\rightarrow +\infty\) 时,Softmax 函数输出的概率接近接近均匀分布,即每个类别的概率相等。

image

温度的高低改变的是 Student 模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student 模型会相对更多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些负标签概率值显著高于平均值的负标签。但由于 Teacher 模型的训练过程决定了负标签部分概率值都比较小,并且负标签的值越低,其信息就越不可靠。因此温度的选取需要进行实际实验的比较,本质上就是在下面两种情况之中取舍:

  • 当想从负标签中学到一些信息量的时候,温度 \(T\) 应调高一些;
  • 当想减少负标签的干扰的时候,温度 \(T\) 应调低一些;

总的来说,\(T\) 的选择和 Student 模型的大小有关,Student 模型参数量比较小的时候,相对比较低的温度就可以了。因为参数量小的模型不能学到所有 Teacher 模型的知识,所以可以适当忽略掉一些负标签的信息。

在实际 \(T\) 可以初步设置为 2 或 3 然后逐步增加,以确定最优的蒸馏温度 \(T\)

在整个知识蒸馏过程中,在训练阶段,我们先让温度 \(T\) 升高。在测试阶段,再恢复正常温度(\(T=1\)),从而将原模型中的知识提取出来,因此将其称为是蒸馏。

3.2.3 目标蒸馏的实现步骤

image
目标蒸馏训练的具体方法如上图所示,主要包括以下几个步骤:

  1. 选择 Teacher 模型和 Student 模型

    Teacher 模型选择一个训练良好且性能优越的复杂模型,通常是深度且大规模的神经网络。而 Student 模型选择一个较为简单、参数较少的模型,以便在硬件受限的环境中部署。

  2. 训练 Teacher 模型

    按照标准训练流程训练 Teacher模型,使其在给定的任务上有较高的准确率。

  3. 定义用于 Student 模型训练的损失函数

    • 计算 Soft Loss
      Soft Loss 是通过计算学生模型 Soft-prediction(高温 \(T_{high}\) 下通过 Softmax 函数处理的模型输出) 与 Teacher 模型的 Soft-target 之间的 KL 散度得到:

      \[L_{\text{soft}} = \text{KL}(q_t^{T_{high}} \parallel q_s^{T_{high}}) = \sum_{i=1}^{n} q_{t,i}^{T_{high}} \log \left(\frac{q_{t,i}^{T_{high}}}{q_{s,i}^{T_{high}}}\right) \]

      其中:
      (1)\(n\) 是总的类别数。

      (2)\(q_{t,i}^{T_{high}}\) 是 Teacher 模型在高温 \(T_{high}\) 下通过 Softmax 函数处理得到的第 \(i\) 类的概率。

      (3)\(q_{s,i}^{T_{high}}\) 是 Student 模型在高温 \(T_{high}\) 下通过 Softmax 函数处理得到的第 \(i\) 类的概率。

      较高的温度 \(T\) 会导致 logits 被缩小,从而使得 Softmax 函数的输出更平滑。但在反向传播过程中,这种缩放会影响梯度的计算,使得梯度被过度缩小。因此,在代码实现时,计算的 KL 散度需要乘以 \(T^2\),用于梯度补偿。

    • 计算 Hard Loss

      Hard Loss 是通过计算学生模型 Hard-prediction(标准 Softmax 函数处理的模型输出)输出与真实标签 Hard-target 之间的交叉熵损失得到:

      \[L_{\text{hard}} = \text{CE}(y, p_s^1) = - \sum_{i=1}^{K} y_i \log(p_{s,i}^1) \]

      其中:

      (1)\(n\) 是总的类别数。

      (2)\(y_i\) 是真实的标签 Hard-target 中第 \(i\) 类的值。

      (3)\(q_{s,i}^1\) 是学生模型在标准 Softmax 函数(\(T=1\))下处理得到的第 \(i\) 类的概率。

    • 综合损失

      高温蒸馏过程的损失函数由 Soft Loss 和 Hard Loss 加权得到。常用损失函数公式如下:

      \[L=(1-\alpha)\cdot L_{soft}+\alpha\cdot L_{hard} \]

      其中 \(\alpha\) 是一个权重参数,用于平衡软目标损失和硬目标损失的贡献。

      真正的损失是由 Soft Loss 和 Hard Loss 综合得到的原因是:Teacher 模型也有一定的错误率,使用真实标签可以有效降低错误信息被传播给 Student 模型的可能性。例如,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。

      \(\alpha\) 的设置取决于你希望学生模型更倾向于学习教师模型输出还是真实标签。通常,\(\alpha\) 的取值应通过实验确定,但可以根据以下指导原则进行初步设置:

      (1)如果你希望学生模型在学习教师模型的同时也能关注真实标签, \(\alpha\) 值可以设置为 0.5,这种方式可以平衡交叉熵损失和蒸馏损失。

      (2)如果 Teacher 模型的性能较好,则可以选择一个较小的 \(\alpha\) 值(例如 0.1 或 0.3),让 Student 模型更多地学习 Teacher 模型的知识。

      (3)如果 Teacher 模型性能不够好,则可以选择一个较大的 \(\alpha\) 值(例如 0.7 或 0.9),让 Student 模型更多地学习真实标签的知识。

      (4)在知识蒸馏的过程中, \(\alpha\) 也可以随训练的进程逐渐变化。这种策略可以帮助学生模型在训练初期更多地学习教师模型的知识,然后逐渐过渡到更多关注真实标签。在训练的早期阶段,学生模型可能还不够强大,更多依赖教师模型的指导,通过设定较小的 \(\alpha\) 值让学生模型更多地学习教师模型的知识。随着训练的进行,学生模型变得更加成熟,可以更多地依赖真实标签,通过逐渐增大 \(\alpha\) 值,使学生模型更多地关注真实标签的损失。

      在实际中, \(\alpha\) 可能需要基于以上方案分别进行训练,以确定适合当前任务的最优 \(\alpha\) 值。

    • 蒸馏损失的 PyTorch 代码实现

      import torch.nn.functional as F
      
      def distillation_loss(student_logits, teacher_logits, targets, alpha, T):
          soft_loss = F.kl_div(
         	 F.log_softmax(student_logits / T, dim=1),
         	 F.softmax(teacher_logits / T, dim=1),
         	 reduction='batchmean'
          ) * T ** 2
          hard_loss = F.cross_entropy(student_logits, targets)
          return alpha * hard_loss + (1 - alpha) * soft_loss
      
  4. 训练 Student 模型

    使用设计好的蒸馏损失函数正常训练 Student 模型。需要注意的是,在测试阶段需要将蒸馏温度恢复成 1,确保学生模型输出与标准 Softmax 输出一致,此时再进行模型评估和测试。

3.4.4 Matching Logits

Matching Logits 是一种特殊的目标蒸馏方法,它在蒸馏过程中不直接使用 Softmax 函数对 Teacher 和 Student 模型的输出进行处理。而是直接比较它们的 logits(即未经过 Softmax 处理的输出值)。这种方法的目标是通过尽可能匹配教师和学生模型在每个类别上的 logits,来传递知识。

具体来说,Matching Logits 的损失函数通常可以定义为 logits 的平方误差或者其他形式的距离度量。一种常见的 Matching Logits 损失函数形式可以如下表示:

\[L=\sum_{i=1}^n\left(z_{t, i}-z_{s, i}\right)^2 \]

其中:

(1)\(n\) 是总的类别数。

(2)\(z_{t,i}\) 是 Teacher 模型对第 \(i\) 个类别的 logits 值。

(3)\(z_{s,i}\) 是 Student 模型对第 \(i\) 个类别的 logits 值。

Matching Logits 方法的优点在于,它避免了 Softmax 函数引入的一些问题,如极端概率值和信息丢失,而直接关注于 logits 的匹配,更加精确地传递 Teacher 模型的知识。Matching Logits 方法的缺点是依赖模型结构,如果两个模型的内部表示和学习机制大相径庭,直接匹配 logits 也可能无法有效传递教师模型的知识。

此时蒸馏损失的 PyTorch 代码实现:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits):
    return F.mse_loss(student_logits, teacher_logits)

3.5 基于特征蒸馏

另外一种知识蒸馏思路是特征蒸馏方法,这种方法最早在论文《FitNets: Hints for Thin Deep Nets》中被提出,其核心思想是让学生模型不仅学习教师模型的最终决策(如概率分布或 logits),而且更重要的是学习教师模型在处理输入数据过程中的中间层特征表示。
image

这篇论文首先提出一个案例,既宽又深的模型通常需要大量的乘法运算,导致对内存和计算资源的高需求,从而限制了它们在现实应用中的广泛使用。为了解决这个问题,可以通过知识蒸馏将知识从复杂的模型转移到参数较少的简单模型,但现有的知识蒸馏技术主要关注 Student 网络与 Teacher 网络具有相同或更小的参数,而未关注特征学习的基本层面——深度。因此,该篇论文主要针对 Hinton 提出的知识蒸馏法进行扩展,允许 Student 模型可以比 Teacher 模型更深更窄,使用 Teacher 网络的输出和中间层的特征作为提示,改进训练过程和 Student 网络的性能。

论文中将这种具有比 Teacher 网络更多的层但每层具有较少神经元数量的 Student 网络称为 "thin deep network"。

3.6 参考文献

主要参考:

目标蒸馏论文:

特征蒸馏论文:

3.7 知识蒸馏代码实现

3.7.1 导入必要的依赖

import os
import time
import numpy as np
import tqdm
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torchvision import transforms, models
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

3.7.2 定义训练早停类

class EarlyStopping:
    def __init__(self, patience=100, min_delta=0):
        if patience <= 0:
            raise ValueError('patience must be a positive integer')
        if min_delta < 0:
            raise ValueError('min_delta must be a positive number')

        self.patience = patience
        self.current_patience = patience
        self.min_delta = min_delta
        self.early_stop = False
        self.best_acc = 0
        self.best_loss = None
        self.best_epoch = None

    def __call__(self, val_loss, val_acc, epoch):
        if val_acc > self.best_acc + self.min_delta:
            self.current_patience = self.patience
            print(f'Epoch {epoch}: '
                  f"best acc improved from {self.best_acc:.4f} to {val_acc:.4f}, "
                  f'current patience reset to {self.patience}')
            self.best_acc = val_acc
            self.best_loss = val_loss
            self.best_epoch = epoch
        else:
            self.current_patience -= 1
            if self.current_patience <= 0:
                self.early_stop = True
                print(f"{self.patience} patience exhausted, early stopping! "
                      f"Best acc: {self.best_acc:.4f}, "
                      f"and best loss: {self.best_loss:.4f},"
                      f"appear at epoch {self.best_epoch}")
            else:
                print(f'Epoch {epoch}: '
                      f'best acc did not improve from {self.best_acc:.4f}, '
                      f'current patience: {self.current_patience}/{self.patience}')
        return self.early_stop

3.7.3 定义训练基础类

class BaseTrainer:
    def __init__(self,
                 num_epochs,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0,
                 ):
        # 设置随机种子
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # 设置设备
        if device == "cuda" and torch.cuda.is_available():
            print("Using GPU for training")
            self.device = torch.device("cuda")
        else:
            if not torch.cuda.is_available():
                print("CUDA is not available, using CPU instead")
            else:
                print("Using CPU for training")
            self.device = torch.device("cpu")

        # 定义输出位置
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):  # 如果不存在该目录,则创建该目录
            os.makedirs(self.save_dir)

        # 加载模型
        self.model = self.load_model()
        self.model.to(self.device)

        # 定义训练参数
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.optimizer, self.scheduler = self.load_optimizer_and_scheduler()
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.history = []
        self.train_loader, self.test_loader = self.get_data_loader()

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_model(self):
        num_classes = 10
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

    def load_optimizer_and_scheduler(self):
        lr = 0.01
        momentum = 0.9
        weight_decay = 1e-4
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.num_epochs)
        print(f"Optimizer: {optimizer.__class__.__name__}, lr={lr}, momentum={momentum}, weight_decay={weight_decay}")
        print(f"Scheduler: {scheduler.__class__.__name__}")
        print(f"You can change the optimizer and scheduler in the load_optimizer_and_scheduler method!")
        return optimizer, scheduler

    def save_model(self, filename):
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, f"{filename}.pt"))

    def save_graph(self, epoch):
        # 展示训练过程
        history_np = np.array(self.history)
        plt.figure(figsize=(10, 10))  # 创建一个10*10的画布绘制损失曲线图
        plt.plot(np.arange(1, epoch + 1), (history_np[:, [0, 2]]))  # 画出训练集损失和验证集损失
        plt.legend(['Train Loss', 'Test Loss'])  # 显示图例
        plt.xlabel('Epoch')  # 设置x轴标签
        plt.ylabel('Loss')  # 设置y轴标签
        x_ticks = np.arange(0, epoch + 1, step=10)
        x_ticks[0] = 1  # epoch是从1开始的
        plt.xticks(x_ticks)  # 设置坐标轴刻度
        plt.yticks(np.arange(0, 2.05, 0.1))  # 设置坐标轴刻度
        plt.grid()  # 画出网格
        plt.gca().set_ylim([0, 2])  # 设置y轴范围
        best_test_loss_idx = np.argmin(history_np[:, 2])
        best_test_loss = history_np[best_test_loss_idx, 2]
        plt.text(
            0.1,
            1.05,
            f'Best test loss: {best_test_loss:.4f}, epoch: {best_test_loss_idx + 1}',
            ha='center',
            va='center',
            transform=plt.gca().transAxes
        )
        plt.savefig(os.path.join(self.save_dir, 'loss_curve.png'))  # 保存图片
        plt.close()  # 关闭画布

        plt.figure(figsize=(10, 10))  # 创建一个10*10的画布绘制准确率曲线图
        plt.plot(np.arange(1, epoch + 1), (history_np[:, [1, 3]]))  # 画出训练集准确率和验证集准确率
        plt.legend(['Train Accuracy', 'Test Accuracy'])  # 显示图例
        plt.xlabel('Epoch')  # 设置x轴标签
        plt.ylabel('Accuracy')  # 设置y轴标签
        x_ticks = np.arange(0, epoch + 1, step=10)
        x_ticks[0] = 1  # epoch是从1开始的
        plt.xticks(x_ticks)  # 设置坐标轴刻度
        plt.yticks(np.arange(0, 1.05, 0.05))  # 设置坐标轴刻度
        plt.grid()  # 画出网格
        best_test_acc_idx = np.argmax(history_np[:, 3])
        best_test_acc = history_np[best_test_acc_idx, 3]
        plt.text(
            0.1,
            1.05,
            f'Best test acc: {best_test_acc:.4f}, epoch: {best_test_acc_idx + 1}',
            ha='center',
            va='center',
            transform=plt.gca().transAxes
        )
        plt.savefig(os.path.join(self.save_dir, 'acc_curve.png'))  # 保存图片
        plt.close()  # 关闭画布

    def train(self, epoch):
        train_total_loss = 0
        train_total_acc = 0
        self.model.train()
        for images, labels in tqdm.tqdm(self.train_loader, total=len(self.train_loader),
                                        desc=f'Training Epoch {epoch}/{self.num_epochs}', delay=0.1):
            images = images.to(self.device)
            labels = labels.to(self.device)
            outputs = self.model(images)
            loss = F.cross_entropy(outputs, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            train_total_loss += loss.item()
            train_total_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
        train_avg_loss = train_total_loss / len(self.train_loader)
        train_avg_acc = train_total_acc / len(self.train_loader.dataset)
        print(f'Epoch {epoch}/{self.num_epochs}, Train Loss: {train_avg_loss:.4f}, Train Acc: {train_avg_acc:.4f}')
        return train_avg_loss, train_avg_acc

    def test(self, epoch):
        test_total_loss = 0
        test_total_acc = 0
        self.model.eval()
        with torch.no_grad():
            for images, labels in tqdm.tqdm(self.test_loader, total=len(self.test_loader),
                                            desc=f'Testing Epoch {epoch}/{self.num_epochs}', delay=0.1):
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(images)
                loss = F.cross_entropy(outputs, labels)
                test_total_loss += loss.item()
                test_total_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
        test_avg_loss = test_total_loss / len(self.test_loader)
        test_avg_acc = test_total_acc / len(self.test_loader.dataset)
        print(f'Epoch {epoch}/{self.num_epochs}, Test Loss: {test_avg_loss:.4f}, Test Acc: {test_avg_acc:.4f}')
        return test_avg_loss, test_avg_acc

    def run(self):
        start_time = time.time()  # 记录开始时间
        if self.history:
            raise Exception("This object has already run, please create a new object")
        for epoch in range(1, 1 + self.num_epochs):
            print(f"Current lr: {self.scheduler.get_last_lr()[0]:.6f}")
            train_loss, train_acc = self.train(epoch)
            test_loss, test_acc = self.test(epoch)
            self.history.append((train_loss, train_acc, test_loss, test_acc))
            self.scheduler.step()
            self.save_graph(epoch)
            self.save_model("last")
            if self.early_stopping(test_loss, test_acc, epoch):
                break
            if self.early_stopping.current_patience == self.early_stopping.patience:
                self.save_model("best")
            print()
        print(f"Training finished. "
              f"Best acc: {self.early_stopping.best_acc:.4f} "
              f"and best loss: {self.early_stopping.best_loss:.4f} "
              f"appear at epoch {self.early_stopping.best_epoch}")
        end_time = time.time()  # 记录结束时间
        period = round(end_time - start_time)
        print(f'All time: {period} s')
        print(f"All Time: {period // 60:d} min {period % 60:d} s")

3.7.4 定义 Teacher 模型训练类

class TeacherTrainer(BaseTrainer):
    def __init__(self,
                 num_epochs,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0,
                 ):
        super().__init__(num_epochs, batch_size, save_dir, device, num_workers, patience, min_delta)

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_model(self):
        num_classes = 10
        model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

3.7.5 定义 Student 基准模型训练类

class BenchmarkTrainer(BaseTrainer):
    def __init__(self,
                 num_epochs,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0,
                 ):
        super().__init__(num_epochs, batch_size, save_dir, device, num_workers, patience, min_delta)

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_model(self):
        num_classes = 10
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

3.7.6 定义蒸馏损失类

class KDLoss:
    def __init__(self, num_epoch, alpha_scope=(0.1, 0.9), temperature=3):
        self.num_epoch = num_epoch
        self.alpha = alpha_scope
        self.temperature = temperature

    def _cal_alpha(self, epoch):
        return self.alpha[0] + (self.alpha[1] - self.alpha[0]) * epoch / self.num_epoch

    def __call__(self, student_logits, teacher_logits, targets, epoch):
        alpha = self._cal_alpha(epoch)
        hard_loss = F.cross_entropy(student_logits, targets)
        soft_loss = F.kl_div(F.log_softmax(student_logits / self.temperature, dim=1),
                             F.softmax(teacher_logits / self.temperature, dim=1),
                             reduction='batchmean') * self.temperature ** 2
        return alpha * hard_loss + (1 - alpha) * soft_loss

3.7.7 定义 Student 模型训练类

class StudentTrainer(BaseTrainer):
    def __init__(self,
                 num_epoch,
                 batch_size,
                 save_dir="weight",
                 device="cuda",
                 num_workers=0,
                 patience=10,
                 min_delta=0,
                 alpha_scope=(0.1, 0.9),
                 temperature=3,
                 teacher_weight_path="teacher/best.pt"
                 ):
        super().__init__(num_epoch, batch_size, save_dir, device, num_workers, patience, min_delta)
        self.teacher_weight_path = teacher_weight_path
        self.kl_loss = KDLoss(num_epoch, alpha_scope, temperature)
        self.teacher_model = self.load_teacher_model()
        self.teacher_model.to(self.device)

    def get_data_loader(self):
        # 计算训练集的mean和std
        dataset = CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
        data_loader = DataLoader(dataset, batch_size=len(dataset), num_workers=self.num_workers)
        data = next(iter(data_loader))
        mean = torch.mean(data[0], dim=(0, 2, 3)).tolist()  # 计算均值
        std = torch.std(data[0], dim=(0, 2, 3)).tolist()  # 计算标准差

        # 定义数据预处理
        train_transform = transforms.Compose([  # 定义训练集的预处理
            transforms.RandomCrop(32, padding=4),  # 随机裁剪,增强模型对图像位置的鲁棒性
            transforms.RandomHorizontalFlip(),  # 随机水平翻转,增加数据的多样性
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动,增强模型对颜色变化的鲁棒性
            transforms.RandomRotation(15),  # 随机旋转,增加模型对旋转变换的鲁棒性
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transform = transforms.Compose([  # 定义测试集的预处理
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # 创建数据集和数据加载器
        train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transform)
        test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transform)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return train_loader, test_loader

    def load_teacher_model(self):
        if not os.path.exists(self.teacher_weight_path):
            raise FileNotFoundError("teacher weight not found")
        num_classes = 10
        model = models.resnet152()
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        model.load_state_dict(torch.load(self.teacher_weight_path))
        return model

    def load_model(self):
        num_classes = 10
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

    def train(self, epoch):
        train_total_loss = 0
        train_total_acc = 0
        self.model.train()
        for images, labels in tqdm.tqdm(self.train_loader, total=len(self.train_loader),
                                        desc=f'Training Epoch {epoch}/{self.num_epochs}', delay=0.1):
            images = images.to(self.device)
            labels = labels.to(self.device)
            outputs = self.model(images)
            teacher_outputs = self.teacher_model(images)
            loss = self.kl_loss(outputs, teacher_outputs, labels, epoch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            train_total_loss += loss.item()
            train_total_acc += (outputs.argmax(dim=1) == labels).float().sum().item()
        train_avg_loss = train_total_loss / len(self.train_loader)
        train_avg_acc = train_total_acc / len(self.train_loader.dataset)
        print(f'Epoch {epoch}/{self.num_epochs}, Train Loss: {train_avg_loss:.4f}, Train Acc: {train_avg_acc:.4f}')
        return train_avg_loss, train_avg_acc

3.7.8 开始训练

if __name__ == '__main__':
    teacher_trainer = TeacherTrainer(
        num_epochs=200,
        batch_size=512,
        save_dir="teacher",
        num_workers=3,
        patience=200,
    )
    teacher_trainer.run()

    benchmark_trainer = BenchmarkTrainer(
        num_epochs=200,
        batch_size=512,
        save_dir="benchmark",
        num_workers=3,
        patience=200,
    )
    benchmark_trainer.run()

    student_trainer = StudentTrainer(
        num_epoch=100,
        batch_size=512,
        save_dir="student",
        num_workers=3,
        patience=100,
        teacher_weight_path="teacher/best.pt"
    )
    student_trainer.run()

4 量化

posted @ 2024-07-04 18:35  gokamisama  阅读(235)  评论(0)    收藏  举报