知识蒸馏
知识蒸馏
在大模型中,知识蒸馏是一种将大型、高精度教师模型中的关键知识提炼并传递给轻量化学生模型的技术。以下是关于知识蒸馏的详细解释:
- 原理:知识蒸馏的核心在于知识的传递和压缩。教师模型通常是参数众多的大型模型,经过充分训练后能捕捉数据中的丰富特征。学生模型则是轻量级模型,通过学习教师模型的输出实现性能优化。传统模型训练依赖硬标签,仅学习输入与输出的映射,而知识蒸馏引入教师模型的软标签,即教师模型经 softmax 处理的概率分布,包含类别间的相似性和置信度信息,学生模型通过学习软目标,能捕捉教师模型的推理逻辑,提升性能。
- 训练过程:首先需要训练一个性能强大的教师模型,然后准备与教师模型一致的训练或微调数据。在训练时,将输入同时送入教师模型和学生模型,得到对应输出,如 logits、内部层表示等。接着计算蒸馏损失,通常使用 KL 散度衡量学生模型和教师模型输出概率分布之间的差异,再结合监督损失,如交叉熵损失,通过反向传播更新学生模型参数,而教师模型通常是冻结的。
- 常见方法:
- Logits 蒸馏:这是最原始也是最常见的蒸馏方式,只关心模型输出层 logits 的模仿,教师的输出分布,即软标签,为学生模型提供了比真实标签更丰富的信息。(logits逻辑回归)
- 特征表示蒸馏:在 Transformer 中,不仅蒸馏最后的 logits,还蒸馏中间层的隐状态或注意力矩阵,让学生在层间细节上更接近教师,适合结构相似的模型。
- 多任务蒸馏:当教师模型是一个多任务或大规模预训练模型时,可以在多个任务数据或多语言数据上进行联合蒸馏,让学生继承教师在不同任务或语言上的知识。
- Progressive Distillation/ Layer - wise Distillation:若学生层数远少于教师层数,则可采用分层逐步蒸馏的策略,让学生更加稳定地学到教师的表征。
- Prompt 蒸馏:在大模型的指令微调或对话场景中,把教师的回答作为一个 “软目标”,让学生学习如何在相同指令下进行回答,使学生具备类似的对话能力,但规模更小。
- 作用:知识蒸馏可以显著降低模型的复杂度和计算量,提高模型的运行效率,加速推理,降低运行成本。同时,模型蒸馏还有可能帮助学生模型学习到教师模型中蕴含的泛化模式,提高其在未见过的数据上的表现,并且轻量化后的模型通常更加简洁明了,有利于理解和分析模型的决策过程,也更容易进行部署和应用。
Logits 蒸馏是知识蒸馏的一种经典方法,其核心思想是让学生模型学习教师模型输出的 Logits 信息,从而实现知识迁移。以下是对 Logits 蒸馏的详细介绍:
Logits 的定义
在分类问题中,输入数据经过神经网络的各种非线性变换后,在网络接近最后一层时,会得到一个向量,该向量中的每个元素代表输入数据属于各个类别的汇总分值,这个向量就是 Logits。Logits 并非概率值,通常需要通过 Softmax 函数将其转换为概率值,以作为最终的分类结果概率。
Logits 蒸馏的基本方法
假设存在一个教师网络和一个学生网络,当输入同一个数据时,教师网络会得到一个 Logits 向量,学生网络也会得到一个 Logits 向量。最早的知识蒸馏工作就是让学生的 Logits 去拟合教师的 Logits,此时学生的损失函数为教师 Logits 和学生 Logits 之间的均方误差等距离度量函数。
Hinton 提出的改进方法 ——Softmax Temperature
Hinton 在论文《Distilling the Knowledge in a Neural Network》中提出了称为 Softmax Temperature 的改进方法。该方法对 Softmax 函数进行了改造,引入了温度 T 这一超参数。当 T 设置为 1 时,就是标准的 Softmax 函数;当 T 设大时,Softmax 之后的 Logits 数值各个类别之间的概率分值差距会缩小,强化了非最大类别的存在感;反之,T 设小时,则会加大类别间概率的两极分化。 Hinton 版本的知识蒸馏让学生去拟合教师经过 T 影响后 Softmax 得到的概率分布,学生的损失函数由两项组成,一个子项是 Ground Truth 的标准交叉熵损失,用于让学生拟合训练数据,另一个是蒸馏损失,用于让学生拟合教师的 Logits,公式为: \(L = \alpha H(y_{gt}, f_s(x)) + (1 - \alpha) D_{KL}(p_T, p_S)\) 其中,H是交叉熵损失函数,\(f_s(x)\)是学生模型的映射函数,\(y_{gt}\)是 Ground Truth Label,\(p_T\)是教师的 Logits 经过 Softmax Temperature 函数处理后的概率分布,\(p_S\)是学生的 Logits 经过 Softmax Temperature 函数处理后的概率分布,\(\alpha\)用于调节蒸馏 Loss 的影响程度。
Logits 蒸馏的优势
基于 Logits 的知识蒸馏方法直接利用模型对样本的预测输出,无需关注神经网络模型的内在结构或特征表达,简单而高效,适用于各种学习任务,包括普通的监督学习、涉及不同领域和模态的学习任务等。并且,它与其他知识蒸馏方法的组合十分灵活,无需额外的设计,进一步提高了其适用性和实用性。

浙公网安备 33010602011771号