提升BERT模型效率与容量的新方法:Pyramid-BERT

近年来,自然语言处理(NLP)领域许多性能最佳的模型都建立在BERT语言模型之上。BERT模型在大规模(未标注)公共文本语料库上进行预训练,编码了单词序列的概率。由于BERT模型一开始就掌握了语言的整体知识,因此只需相对较少的标注数据,就能针对特定任务(如问答或机器翻译)进行微调。

然而,BERT模型体量庞大,基于BERT的NLP模型可能运行缓慢,对于计算资源有限的用户来说甚至慢到难以接受。其复杂性也限制了可处理的输入长度,因为其内存占用随输入长度的平方而增长。

在今年计算语言学协会(ACL)的会议上,本文作者及其同事提出了一种名为Pyramid-BERT的新方法。该方法能在几乎不损失准确性的前提下,减少基于BERT模型的训练时间、推理时间和内存占用。减少的内存占用也使BERT模型能够处理更长的文本序列。

基于BERT的模型将句子序列作为输入,并输出整个句子及其各个单词的向量表示(嵌入)。然而,文本分类和排序等下游应用仅使用完整的句子嵌入。为了使基于BERT的模型更高效,该方法尝试在网络的中间层逐步消除冗余的单个单词嵌入,同时尽量减少对完整句子嵌入的影响。

将Pyramid-BERT与几种最先进的BERT模型效率优化技术进行比较,结果显示,该方法能将推理速度提高3到3.5倍,而准确率仅下降1.5%;在相同速度下,现有最佳方法的准确率损失为2.5%。此外,当将此方法应用于专为长文本设计的BERT变体Performers时,能将模型的内存占用减少70%,同时甚至提高了准确性。在此压缩率下,现有最佳方法的准确率会下降4%。

标记的处理过程

输入BERT模型的每个句子都被分解为称为“标记”的单位。大多数标记是单词,但有些是多词短语、子词部分、缩写的单个字母等。每个句子的开头由一个特殊的标记(称为CLS)来标示。

每个标记通过一系列编码器(通常在4到12个之间),每个编码器为每个输入标记生成一个新的嵌入向量。每个编码器都有一个注意力机制,用于决定每个标记的嵌入应反映多少由其他标记携带的信息。

当标记通过一系列编码器时,它们的嵌入会包含越来越多关于序列中其他标记的信息,因为它们会关注那些同样也在整合越来越多信息的其他标记。当标记通过最终的编码器时,CLS标记的嵌入最终代表了整个句子。但它的嵌入也与句子中所有其他标记的嵌入非常相似。这正是该方法试图消除的冗余。

核心思路

基本思路是,在网络中的每个编码器层,保留CLS标记的嵌入,但从其他标记的嵌入中选择一个具有代表性的子集(即核心集)。

嵌入是向量,因此可以解释为多维空间中的点。理想情况下,为了构建核心集,我们会将嵌入分类为等直径的簇,并选择每个簇的中心点(质心)。

然而,构建一个跨越神经网络层的核心集问题是NP难问题,意味着其耗时将长得不切实际。

作为替代方案,该论文提出了一种贪心算法,每次从核心集中选择n个成员。在每一层,我们取CLS标记的嵌入,然后在表示空间中找到距离它最远的n个嵌入。我们将这些连同CLS嵌入一起添加到核心集中。接着,我们找到那些与核心集中已有任一点的最小距离最大的n个嵌入,并将它们添加到核心集中。

我们重复这个过程,直到核心集达到所需的大小。这被证明是足够接近最优核心集的近似。

最后,论文还探讨了每一层核心集应该多大。作者使用指数延迟函数来确定从一层到下一层的衰减程度,并研究了在选择不同的衰减率时,准确性与加速或内存减少之间的权衡关系。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

公众号二维码

公众号二维码

posted @ 2025-12-12 16:21  CodeShare  阅读(3)  评论(0)    收藏  举报