人工智能概念:常用的模型压缩技术(剪枝、量化、知识蒸馏) - 详解
文章目录
一、模型压缩概述
1.1 什么是模型压缩?
模型压缩是一类通过减少模型参数数量、降低计算复杂度,从而在资源受限设备上高效部署深度学习模型的技术。其核心目标是在模型性能损失最小化的前提下,显著减小模型体积、降低内存占用、提升推理速度,以适应移动端、嵌入式设备等资源受限场景的需求。
1.2 为什么需要模型压缩?
随着Transformer等大模型的兴起,模型参数规模呈指数级增长。例如,原始BERT-base模型参数量约110M,推理时不仅占用大量内存,还要求较高的计算资源,难以直接部署在手机、摄像头等边缘设备上。此外,大模型的高推理延迟也无法满足实时性要求较高的业务场景(如实时推荐、语音助手)。
模型压缩的意义在于:
- 降低存储成本:减小模型文件大小,节省存储空间。
- 提升推理速度:减少计算量,降低延迟,满足实时性需求。
- 降低部署门槛:使模型能够在算力有限的边缘设备上运行。
- 减少能耗:降低推理过程中的能量消耗,适合移动设备。
1.3 四种主流模型压缩技术
目前,业界常用的模型压缩技术主要有四类:
| 技术名称 | 核心思想 | 特点 |
|---|---|---|
| 剪枝(Pruning) | 移除模型中冗余的参数(如权重值接近0的连接),保留关键参数。 | 可分为结构化剪枝(移除整个通道/层)和非结构化剪枝(移除单个权重)。 |
| 量化(Quantization) | 用低精度数据类型(如int8)替代高精度类型(如float32)表示权重和激活值。 | 模型体积缩小4-8倍,推理速度提升2-4倍,实现简单。 |
| 知识蒸馏(Knowledge Distillation) | 让小模型(学生)学习大模型(教师)的“知识”(如软标签),模仿其行为。 | 保留大模型性能的同时,显著减小模型规模,适用于复杂模型压缩。 |
| 低秩因式分解(Low-rank Factorization) | 将高维权重矩阵分解为多个低维矩阵的乘积,减少参数数量。 | 适合线性层、卷积层等矩阵运算密集的模块,压缩率较高但实现较麻烦。 |
二、模型量化:用低精度换高效能

2.1 量化的数学原理
量化的核心是将高精度浮点数(如float32)映射到低精度整数(如int8),核心公式涉及缩放因子(scale) 和零点(zero_point) 的计算。
基本定义
- 设浮点数范围为[ x min , x max ] [x_{\text{min}}, x_{\text{max}}][xmin,xmax],对应整数范围为[ q min , q max ] [q_{\text{min}}, q_{\text{max}}][qmin,qmax](如int8的[ − 128 , 127 ] [-128, 127][−128,127])。
- 缩放因子 s ss:控制浮点数到整数的比例映射。
- 零点 z zz:确保映射的偏移量(使0附近的浮点数能准确映射)。
核心公式
s = x max − x min q max − q min (1) s = \frac{x_{\text{max}} - x_{\text{min}}}{q_{\text{max}} - q_{\text{min}}} \tag{1}s=qmax−qminxmax−xmin(1)
z = q min − round ( x min s ) (2) z = q_{\text{min}} - \text{round}\left(\frac{x_{\text{min}}}{s}\right) \tag{2}z=qmin−round(sxmin)(2)
q = clip ( round ( x s + z ) , q min , q max ) (3) q = \text{clip}\left(\text{round}\left(\frac{x}{s} + z\right), q_{\text{min}}, q_{\text{max}}\right) \tag{3}q=clip(round(sx+z),qmin,qmax)(3)- 式(1):计算缩放因子,将浮点数范围映射到整数范围。
- 式(2):计算零点,确保x min x_{\text{min}}xmin能映射到q min q_{\text{min}}qmin。
- 式(3):将浮点数x xx量化为整数q qq,并裁剪到整数范围内。
反量化公式(推理时还原)
x recon = s ⋅ ( q − z ) (4) x_{\text{recon}} = s \cdot (q - z) \tag{4}xrecon=s⋅(q−z)(4)
2.2 量化计算示例
以float32到int8的量化为例,假设某层权重的浮点数范围为[ − 1.2 , 3.6 ] [-1.2, 3.6][−1.2,3.6],计算过程如下:
步骤1:确定范围
- 浮点数:x min = − 1.2 x_{\text{min}} = -1.2xmin=−1.2,x max = 3.6 x_{\text{max}} = 3.6xmax=3.6
- int8整数:q min = − 128 q_{\text{min}} = -128qmin=−128,q max = 127 q_{\text{max}} = 127qmax=127,范围长度 127 − ( − 128 ) = 255 127 - (-128) = 255127−(−128)=255
步骤2:计算缩放因子s ss
s = 3.6 − ( − 1.2 ) 255 = 4.8 255 ≈ 0.0188 s = \frac{3.6 - (-1.2)}{255} = \frac{4.8}{255} \approx 0.0188s=2553.6−(−1.2)=2554.8≈0.0188
步骤3:计算零点z zz
z = − 128 − round ( − 1.2 0.0188 ) = − 128 − round ( − 63.83 ) = − 128 + 64 = − 64 z = -128 - \text{round}\left(\frac{-1.2}{0.0188}\right) = -128 - \text{round}(-63.83) = -128 + 64 = -64z=−128−round(0.0188−1.2)=−128−round(−63.83)=−128+64=−64
步骤4:量化单个浮点数
例如量化 x = 0.5 x = 0.5x=0.5:
q = round ( 0.5 0.0188 + ( − 64 ) ) = round ( 26.59 − 64 ) = round ( − 37.41 ) = − 37 q = \text{round}\left(\frac{0.5}{0.0188} + (-64)\right) = \text{round}(26.59 - 64) = \text{round}(-37.41) = -37q=round(0.01880.5+(−64))=round(26.59−64)=round(−37.41)=−37
- 量化结果:q = − 37 q = -37q=−37(在int8范围内)。
步骤5:反量化验证
x recon = 0.0188 × ( − 37 − ( − 64 ) ) = 0.0188 × 27 ≈ 0.5076 ≈ 0.5 x_{\text{recon}} = 0.0188 \times (-37 - (-64)) = 0.0188 \times 27 \approx 0.5076 \approx 0.5xrecon=0.0188×(−37−(−64))=0.0188×27≈0.5076≈0.5
- 误差:∣ 0.5076 − 0.5 ∣ = 0.0076 |0.5076 - 0.5| = 0.0076∣0.5076−0.5∣=0.0076,精度损失较小。
2.3 量化相关API详解
- PyTorch量化API
| API名称 | 功能描述 | 关键参数说明 | 适用场景 |
|---|---|---|---|
torch.quantization.quantize_dynamic | 动态量化模型,推理时实时计算scale和zero_point | - model:待量化模型- qconfig_spec:指定需量化的层类型(如{torch.nn.Linear})- dtype:目标数据类型(如torch.qint8) | 飞快部署、动态输入场景 |
torch.quantization.prepare | 为静态量化准备模型(插入观测器) | - model:待准备模型- qconfig:量化配置(如torch.quantization.get_default_qconfig('fbgemm')) | 静态量化(需校准数据) |
torch.quantization.convert | 将准备好的模型转换为量化模型 | - model:经prepare处理的模型 | 静态量化(精度更高) |
torch.quantization.QConfig | 定义量化配置(如激活和权重的量化方式) | - activation:激活量化方式(如FakeQuantize.with_args(observer=MovingAverageMinMaxObserver))- weight:权重量化方式 | 自定义量化策略 |
- TensorFlow量化API
| API名称 | 功能描述 | 关键参数说明 |
|---|---|---|
tf.quantization.quantize | 对张量进行量化(帮助动态范围量化) | - input:待量化张量- min_range/max_range:输入范围- T:目标类型(如tf.int8) |
tf.keras.layers.experimental.QuantizationAwareTraining | 量化感知训练(模拟量化过程,提升量化后精度) | - input_shape:输入形状- num_bits:量化位数 |
- ONNX Runtime量化API
| API名称 | 功能描述 | 关键参数说明 |
|---|---|---|
onnxruntime.quantization.quantize_dynamic | 动态量化ONNX模型 | - input_model:输入ONNX模型路径- output_model:输出量化模型路径- op_types_to_quantize:需量化的算子类型(如['MatMul', 'Add']) |
- 量化注意事项
- 动态量化适合CPU端部署,GPU量化建议应用TensorRT的INT8校准工具。
- 量化对模型精度的影响与任务相关:图像分类通常比目标检测更耐量化,文本分类比NER更耐量化。
- 混合精度量化(如部分层用float16,部分用int8)可在精度和速度间取得更好平衡。
三、知识蒸馏:让小模型学会大模型的“智慧”

3.1 知识蒸馏的数学原理
知识蒸馏的核心是借助KL散度衡量学生模型与教师模型的输出差异,结合硬标签损失优化学生模型。
软标签生成
教师模型的logits经过温度T TT调整后生成软标签:
p i = exp ( z i / T ) ∑ j exp ( z j / T ) (5) p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \tag{5}pi=∑jexp(zj/T)exp(zi/T)(5)- z i z_izi:教师模型对第i ii类的logits输出。
- T TT:温度参数(T > 1 T>1T>1使分布更平滑,保留更多知识)。
KL散度损失(软标签损失)
衡量学生软标签q qq与教师软标签p pp的差异:
L KL = ∑ i p i log ( p i q i ) (6) L_{\text{KL}} = \sum_i p_i \log\left(\frac{p_i}{q_i}\right) \tag{6}LKL=i∑pilog(qipi)(6)- 当T = 1 T=1T=1时,KL散度退化为交叉熵损失。
总损失函数
L total = α ⋅ L KL + ( 1 − α ) ⋅ L CE (7) L_{\text{total}} = \alpha \cdot L_{\text{KL}} + (1-\alpha) \cdot L_{\text{CE}} \tag{7}Ltotal=α⋅LKL+(1−α)⋅LCE(7)- L CE L_{\text{CE}}LCE:学生模型与真实标签的交叉熵损失(硬标签损失)。
- α \alphaα:软标签损失的权重(通常取0.5-0.9)。
3.2 蒸馏计算示例
以三分类任务为例,演示损失计算过程:
步骤1:模型输出
- 教师模型logits:z teacher = [ 3.0 , 1.0 , 0.2 ] z_{\text{teacher}} = [3.0, 1.0, 0.2]zteacher=[3.0,1.0,0.2]
- 学生模型logits:z student = [ 2.5 , 0.8 , 0.1 ] z_{\text{student}} = [2.5, 0.8, 0.1]zstudent=[2.5,0.8,0.1]
- 真实标签:y = [ 1 , 0 , 0 ] y = [1, 0, 0]y=[1,0,0](第0类)
步骤2:生成软标签(T = 2.0 T=2.0T=2.0)
- 教师软标签:
p = [ exp ( 3 / 2 ) ∑ , exp ( 1 / 2 ) ∑ , exp ( 0.2 / 2 ) ∑ ] ≈ [ 0.721 , 0.215 , 0.064 ] p = \left[ \frac{\exp(3/2)}{\sum}, \frac{\exp(1/2)}{\sum}, \frac{\exp(0.2/2)}{\sum} \right] \approx [0.721, 0.215, 0.064]p=[∑exp(3/2),∑exp(1/2),∑exp(0.2/2)]≈[0.721,0.215,0.064] - 学生软标签:
q = [ exp ( 2.5 / 2 ) ∑ , exp ( 0.8 / 2 ) ∑ , exp ( 0.1 / 2 ) ∑ ] ≈ [ 0.659 , 0.257 , 0.084 ] q = \left[ \frac{\exp(2.5/2)}{\sum}, \frac{\exp(0.8/2)}{\sum}, \frac{\exp(0.1/2)}{\sum} \right] \approx [0.659, 0.257, 0.084]q=[∑exp(2.5/2),∑exp(0.8/2),∑exp(0.1/2)]≈[0.659,0.257,0.084]
步骤3:计算损失
- KL散度损失(L KL L_{\text{KL}}LKL):
L KL = 0.721 log ( 0.721 / 0.659 ) + 0.215 log ( 0.215 / 0.257 ) + 0.064 log ( 0.064 / 0.084 ) ≈ 0.018 L_{\text{KL}} = 0.721\log(0.721/0.659) + 0.215\log(0.215/0.257) + 0.064\log(0.064/0.084) \approx 0.018LKL=0.721log(0.721/0.659)+0.215log(0.215/0.257)+0.064log(0.064/0.084)≈0.018 - 硬标签损失(L CE L_{\text{CE}}LCE):
L CE = − log ( q 0 ) ≈ − log ( 0.659 ) ≈ 0.418 L_{\text{CE}} = -\log(q_0) \approx -\log(0.659) \approx 0.418LCE=−log(q0)≈−log(0.659)≈0.418 - 总损失(α = 0.7 \alpha=0.7α=0.7):
L total = 0.7 × 0.018 + 0.3 × 0.418 ≈ 0.138 L_{\text{total}} = 0.7 \times 0.018 + 0.3 \times 0.418 \approx 0.138Ltotal=0.7×0.018+0.3×0.418≈0.138
3.3 知识蒸馏相关API详解
- Hugging Face Transformers API
| API名称/工具 | 功能描述 | 关键参数说明 |
|---|---|---|
transformers.Trainer | 通过自定义损失函数实现蒸馏(教师模型固定,学生模型训练) | - model:学生模型- args:训练参数(如TrainingArguments)- compute_loss:自定义损失函数(融合KL散度和交叉熵) |
transformers.DistilBertForSequenceClassification | 预训练蒸馏模型(如DistilBERT,学生模型) | - 继承自PreTrainedModel,可直接加载预训练权重(如distilbert-base-uncased) |
- PyTorch蒸馏器具
| API名称 | 功能描述 | 关键参数说明 |
|---|---|---|
torch.nn.KLDivLoss | 计算KL散度损失(软标签损失) | - reduction:损失聚合方式(如'batchmean')- log_target:是否目标为对数形式 |
torch.nn.CrossEntropyLoss | 计算交叉熵损失(硬标签损失) | - weight:类别权重- reduction:损失聚合方式 |
- 专用蒸馏库
| 库名称 | 功能描述 | 核心API示例 |
|---|---|---|
HuggingFace/transformers中的蒸馏工具 | 提供DistilBERT、DistilRoBERTa等蒸馏模型的训练逻辑 | from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
knowledge-distillation-pytorch | 轻量级蒸馏库(支持多种蒸馏策略) | from kd import KnowledgeDistillationLoss(融合KL散度和硬标签损失) |
四、模型剪枝:移除冗余参数,保留核心能力

4.1 剪枝的数学原理
剪枝通过评估参数重要性移除冗余权重,常用L1范数衡量重要性(值越小越冗余)。
L1范数重要性评估
对于权重矩阵W ∈ R m × n W \in \mathbb{R}^{m \times n}W∈Rm×n,单个权重w i j w_{ij}wij的重要性为:
I ( w i j ) = ∣ w i j ∣ (8) \mathcal{I}(w_{ij}) = |w_{ij}| \tag{8}I(wij)=∣wij∣(8)全局剪枝阈值计算
若剪枝比例为r rr,则阈值θ \thetaθ满足:
∑ i , j I ( ∣ w i j ∣ < θ ) 总参数数 = r (9) \frac{\sum_{i,j} \mathbb{I}(|w_{ij}| < \theta)}{\text{总参数数}} = r \tag{9}总参数数∑i,jI(∣wij∣<θ)=r(9)- I ( ⋅ ) \mathbb{I}(\cdot)I(⋅)为指示函数,满足条件时取1。
4.2 剪枝计算示例
以3x3权重矩阵为例,剪枝30%的参数:
步骤1:原始权重矩阵
W = [ 0.1 − 0.02 0.05 − 0.3 0.01 0.2 0.03 − 0.04 0.08 ] W = \begin{bmatrix} 0.1 & -0.02 & 0.05 \\ -0.3 & 0.01 & 0.2 \\ 0.03 & -0.04 & 0.08 \end{bmatrix}W=0.1−0.30.03−0.020.01−0.040.050.20.08
步骤2:计算L1范数(重要性)
∣ I ( W ) ∣ = [ 0.1 0.02 0.05 0.3 0.01 0.2 0.03 0.04 0.08 ] |\mathcal{I}(W)| = \begin{bmatrix} 0.1 & 0.02 & 0.05 \\ 0.3 & 0.01 & 0.2 \\ 0.03 & 0.04 & 0.08 \end{bmatrix}∣I(W)∣=0.10.30.030.020.010.040.050.20.08
步骤3:排序并确定阈值
将所有权重按L1范数升序排列:0.01 , 0.02 , 0.03 , 0.04 , 0.05 , 0.08 , 0.1 , 0.2 , 0.3 0.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.30.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.3
总参数9个,剪枝30%即移除3个参数,阈值θ = 0.03 \theta=0.03θ=0.03(第3小的值)。
步骤4:剪枝后矩阵(小于θ \thetaθ的权重置0)
W pruned = [ 0.1 0 0.05 − 0.3 0 0.2 0 − 0.04 0.08 ] W_{\text{pruned}} = \begin{bmatrix} 0.1 & 0 & 0.05 \\ -0.3 & 0 & 0.2 \\ 0 & -0.04 & 0.08 \end{bmatrix}Wpruned=0.1−0.3000−0.040.050.20.08
- 稀疏度:3/9=33.3%(接近目标30%)。
4.3 模型剪枝相关API详解
- PyTorch剪枝API
| API名称 | 功能描述 | 关键参数说明 |
|---|---|---|
torch.nn.utils.prune.l1_unstructured | 对单个模块进行L1非结构化剪枝 | - module:待剪枝模块(如model.bert.encoder.layer[0].attention.self.query)- name:待剪枝参数名(如'weight')- amount:剪枝比例(如0.3) |
torch.nn.utils.prune.global_unstructured | 对多个模块进行全局非结构化剪枝(统一阈值) | - parameters:待剪枝参数列表(如[(module, 'weight')])- pruning_method:剪枝方法(如prune.L1Unstructured)- amount:剪枝比例 |
torch.nn.utils.prune.remove | 永久移除剪枝掩码(将0值权重保留在参数中) | - module:已剪枝模块- name:剪枝参数名 |
torch.nn.utils.prune.ln_structured | 对模块进行结构化剪枝(如按通道剪枝) | - n:剪枝维度(如0表示按输出通道剪枝)- amount:剪枝比例- pruning_method:重要性评估方法(如'l1_unstructured') |
- TensorFlow剪枝API
| API名称 | 功能描述 | 关键参数说明 |
|---|---|---|
tfmot.sparsity.keras.prune_low_magnitude | 对Keras模型进行 magnitude-based 剪枝 | - model:待剪枝模型- pruning_schedule:剪枝调度(如PolynomialDecay) |
tfmot.sparsity.keras.PolynomialDecay | 定义剪枝比例随训练步数的变化策略 | - initial_sparsity:初始稀疏度- final_sparsity:目标稀疏度- num_steps:总步数 |
- 第三方剪枝应用
| 工具名称 | 功能描述 | 核心特点 |
|---|---|---|
TorchPrune | 支持PyTorch模型的结构化和非结构化剪枝 | 提供剪枝后模型微调应用,帮助可视化剪枝效果 |
PruneTorch | 轻量级剪枝库(支持Transformer、ResNet等主流模型) | 完成简单,适合快速验证剪枝效果 |
4.4 剪枝注意事项
- 非结构化剪枝生成稀疏矩阵,需硬件支持(如NVIDIA的Sparse Tensor Core)才能加速,否则可能变慢。
- 结构化剪枝(如按通道)生成密集矩阵,无需特殊硬件,但剪枝比例过高会导致精度大幅下降。
- 剪枝后需微调模型(fine-tuning),恢复因剪枝丢失的性能(通常微调3-5个epoch即可)。
五、总结
模型压缩技术凭借数学原理与工程实现的结合,在精度与效率间取得平衡,其核心API为工业界部署献出了便捷工具:
- 量化:通过
torch.quantization.quantize_dynamic等API实现低精度转换,适合追求极致部署效率的场景,API使用简单但需注意精度权衡。 - 蒸馏:基于
KLDivLoss与CrossEntropyLoss组合,或使用DistilBERT等预训练蒸馏模型,适合需要保留高精度的小模型场景。 - 剪枝:通过
global_unstructured等API移除冗余参数,适合对模型大小敏感且可接受一定部署复杂度的场景。
实际应用中,可组合多种技术(如“剪枝+量化”)进一步提升压缩效果,例如先剪枝移除30%冗余参数,再量化为int8,可在精度损失5%以内实现模型体积缩减80%以上,推动大模型在边缘设备的落地。

浙公网安备 33010602011771号