使-LLM-推理速度提高-3-倍的训练目标

使 LLM 推理速度提高 3 倍的训练目标

原文:towardsdatascience.com/why-weve-been-optimizing-the-wrong-thing-in-llms-for-years/

引言

标准的大语言模型(LLMs)的训练目标是简单的:下一个标记预测(NTP)。通过最大化给定前文情况下立即后续标记 x[t+1] 的概率,模型实现了显著的流畅性和推理能力。

然而,这种方法实际上非常低效,因为模型在预测填充词(例如,“the”,“and”,“have”)和信息承载词(例如,“red”,“apple”,“lazy”)时必须花费相同数量的计算。这一点由于超过 50% 的英语单词都是填充词(Nordquist,2024 年)³ 而加剧。这提出了一个实际问题:是否所有单词都需要完整的推理周期来预测,或者模型在预测之前就已经在它们的隐藏状态中有了填充词?

MTP 的动机

最近的经验研究表明,transformers 能够处理不仅仅是立即的下一步。Pal 等人(2023 年)¹ 的研究表明,transformer 模型的内部表示通常在生成之前就编码了未来文本的轨迹。

为了说明这一点,研究人员进行了一项“移植”实验。他们从处理句子“Madison Square Garden is located in…”的模型中提取了隐藏状态——就在它即将预测下一个词为“New”之前。然后,他们将这个向量放入处理完全无关上下文的模型中,例如“Tell me something about…”。尽管提示与上下文无关,但模型自动回归地完成了句子“Tell me something about New York City。”这证实了模型不仅编码了下一个标记,而且编码了整个未来的序列。

为了利用 LLMs 的这种潜在能力,Meta FAIR 的研究人员(Gloeckle 等人,2024 年)² 提出了一种新颖的方法。他们不是将这种预见性视为一个意外的副产品,而是将其明确地用作训练目标。通过要求模型在每个位置同时预测“n”个未来标记,而不是只预测一个,他们有效地使模型能够向前看。作者们证明了多标记预测(MTP)范式在各种基准测试上取得了显著更强的性能,同时将推理速度提高了高达 3 倍,比基线快得多。

MTP 架构:并行预测

如果下一个几个标记的信息已经嵌入到 LLM 的当前隐藏状态中,那么问题就变成了架构性的:我们如何提前提取这些信息,而不增加与标准 NTP 相比的计算需求?

作者提出的架构旨在修改现有的 Transformer 主干以同时预测 n 个未来的标记。与仅针对立即下一个标记 (x[t+1]) 最小化交叉熵损失的标准的 NTP 范式不同,多标记预测 (MTP) 最小化 n 个不同输出头上的平均损失:

图片[θ]: 表示整个模型作为一个函数

为了实现这一点,作者将模型分为两个部分:

  1. 共享主干 (f[s]): 模型的主体是一个标准的 Transformer 主干,其任务是处理提示上下文 x[1:t]​ 并将其转换为信息丰富的全局表示 z[t]​,该表示将被用于所有后续预测。

  2. 独立头 (f[h_i]): 主干的输出被馈送到 n 个独立的头。每个头都有自己的 Transformer 层,并负责预测未来的偏移标记(例如,头 1 预测 t+1,头 2 预测 t+2,等等)。

最终,每个单独头的输出都传递到共享反嵌入层,该层实现为从模型的隐藏维度到词汇长度的简单线性投影。下面的图示用于总结 MTP 架构最重要的方面:

图片

(来源:作者)

模型只处理共享主干一次。然后,它按顺序激活每个头。对于步骤 4-6,它激活第一个头,计算其 logits,然后在步骤 6-8 中进行反向传播。以类似的方式激活头 2,然后是头 3 和头 4。

克服内存瓶颈

上文所述的架构提出了一个重大的工程挑战:GPU 内存利用率。

大型语言模型的词汇量 (V) 通常在 32k-256k 的范围内,这是一个天文数字。这使得词汇表中每个单词的原始预测分数,即输出 logits,也非常大。在标准的 NTP 设置中,模型只需要在每个步骤中实现这些 logits 一次,这使得它变得可行。然而,在 MTP 设置中,同时产生 n 个不同的大规模 logits 集合,这很容易超过 GPU 内存。这使得 MTP 方法对于研究人员来说不切实际,除非他们大幅减少批量大小,从而减慢整个训练过程。

作者通过序列前向/反向传递策略绕过这个瓶颈。而不是同时计算所有 n 个头的损失,训练循环按顺序遍历它们:

  1. 共享主干计算潜在状态 z[t]

  2. 模型计算头 1 的 logits,计算损失,在整个模型中反向传播梯度,并立即丢弃内存中的 logits。

  3. 然后它重复这个过程,对头 2、头 3 等进行操作。

通过在每个头计算后从内存中删除这些庞大的 logit 向量,训练过程的峰值内存使用量保持在 O(V),而不是 O(nV)。这使得 MTP 模型可以以与标准模型相似的批量大小进行训练。

关键设计选择

除了内存优化之外,作者还做出了两个具体的设计决策,这对于理解 MTP 的性能指标和科学有效性至关重要。

1. 参数平衡约束

在具有 n=4 头的 MTP 模型中,带有变压器骨干的四个额外头层导致参数增加。为了补偿这种增加,作者从模型的主干中移除了等量的层,使其变浅。这样做是为了确保 MTP 相对于基线在性能上的任何变化都可以完全归因于 MTP 架构本身,而不是模型参数的增加。

尽管 MTP 的主干较浅,但仍然优于基于标准 NTP 的模型,这一事实进一步证明了该架构的优点。

2. 头拓扑:并行与因果

作者还尝试了头部的排列方式,具体比较了两种方法:

  • 并行头: 这是上述标准 MTP 设计。在这个设计中,每个头仅根据共享状态 z[t] 预测其特定的未来令牌,而不会看到其他头的预测。

  • 因果头: 在这个设置中,头 2(预测 t+2)将接收头 1 的输出作为输入。这会在模型末尾创建一个“迷你自回归”链,允许每个头查看前一个头的状态。具有 n=4 因果头的 MTP 架构如下所示:

(来源:作者)

在因果设计中,头以顺序排列。这样做是为了让每个头知道前一个头预测了什么。

令人惊讶的是,并行设计表现更好。作者假设在具有因果头的结构中,共享的主干“变得懒惰”,依赖于头去推断序列信息。但是,通过迫使头独立行动,主干实际上被有效地迫使学习一个全局表示,这可以同时满足所有头。这正是模型能够进行未来规划的确切属性,这在推理任务中是必不可少的。

实验结果:改进的规模

作者对 MTP 模型与从 300M 到 13B 参数大小的标准 Next-Token Prediction (NTP) 基线进行了广泛的评估比较。

1. 多令牌预测的“缩放定律”

可以说,最有趣的发现是模型的性能与其规模成正比。对于 300M-1.3B 参数的小型模型,MTP 与 NTP 之间的差异可以忽略不计(很多时候 MTP 的表现更差)。但随着规模的增加,MTP 开始显著优于基线。如图所示,MTP 在 MBPP 基准测试上优于 NTP 17%,在 HumanEval 基准测试上优于 12%。

图片

(来源:改编自 Gloeckle 等人(2024b),图 3)

注意:这些图表显示了与基线相比的绝对点变化。例如,在左上角的图表中,13B NTP 模型在 MBPP 基准测试上得分为 26%,而 MTP 得分为 30.5%,在绝对意义上增加了 4.5 个百分点,在相对意义上增加了 17%。

这种差异的可能原因可能是,较大的模型由于拥有更多的参数数量,能够比较小的模型分配更多的容量用于未来规划。这使得较大的模型能够利用多令牌目标来发展更高级的推理能力。

2. 通过自我推测实现三倍推理速度提升

除了性能指标外,MTP 还解决了 LLM 操作中最持久的瓶颈之一:推理延迟。

要完全欣赏这一贡献,我们首先必须理解什么是推测性解码。在标准推理中,模型必须迭代地生成令牌。它必须等待 x[t] 生成后才能计算 x[t+1]。推测性解码通过使用较小的、较快的草稿模型(通常与主模型属于同一系列但参数较少)来加快这个过程,该模型接收来自主模型的隐藏状态并预测接下来的几个令牌。然后,主模型负责在单次前向传递中验证所有这些令牌,确保它与较小模型的预测一致。由于单次前向传递比通过多次迭代生成令牌要快,这导致净速度提升。(了解更多关于推测性解码的信息](https://medium.com/ai-science/speculative-decoding-make-llm-inference-faster-c004501af120))

推测性解码通常需要将较小的模型加载到内存中,这可能会很消耗内存。然而,作者提出,额外的 MTP 头(通常在训练后会被丢弃)可以用作内置草稿模型的角色。由于这些头共享相同的主体,这些头是高度准确的草稿师。通过使用多达四个头草拟子序列,并在并行中验证它,MTP 在推理中实现了 3 倍的速度提升,同时没有损失性能准确性。

4. “归纳头”的快速形成

作者们还分析了 MTP 中归纳能力的出现。归纳头是变压器中的电路,主要负责模式匹配能力(例如,识别[A]…[B]…[A]很可能是[B])。下方的图表显示,对于较小的模型尺寸,MTP 比同样大小的 NTP 模型显示出更强的归纳能力。这表明,通过迫使模型预测下一个立即的标记的后果,它产生了一个有助于模式识别和上下文学习的梯度信号。

(来源:改编自 Gloeckle 等人(2024b),图 7)

作者们选取了 100 个儿童故事,并将人物的名字替换为跨越两个标记的名字。y 轴上绘制的归纳成功是模型正确预测两个标记名字中第二个标记的准确率,前提是该名字至少被模型展示过一次。

5. 解锁字节级训练

在一个更激进的实验中,作者们将 MTP 应用于字节级模型,这些模型预测的是一系列字节而不是标记表示。历史上,字节级模型一直表现不佳,因为字节之间的上下文信息较弱,字节序列往往变得非常大。然而,如下表所示,使用n=8个头(一次预测 8 个字节),MTP 模型在所有三个基准测试中均显著优于具有n=1个头的基线 NTP 模型。这表明 MTP 模型可以有效地在字节领域导航,允许模型以原生方式处理原始数据,而不会在性能上做出任何妥协。

(来源:改编自 Gloeckle 等人(2024b),表 1)

此表展示了 MTP 和 NTP 模型在不同基准测试上的 Pass@k 准确率。例如,@10 列衡量的是模型生成的至少一个解决方案在前 10 个解决方案中是正确的概率。

预见之代价:不足和权衡

尽管多标记预测为标准范式提供了一个有吸引力的替代方案,但论文的结果明确指出,它不是一个通用的“银弹”。该架构引入了工程师必须考虑的具体权衡。

1. 知识密集型任务的回归

虽然 MTP 提高了推理(如何构建答案)的能力,但它似乎损害了检索(知道一个具体的事实)。

如下所示,MTP 模型在代码生成和推理基准测试中占主导地位,但实际上在标准 NLP 任务上表现不佳,包括像 MMLU、TriviaQA 和 ARC Challenge(测试事实检索和世界知识)这样的基准测试。

(来源:改编自 Gloeckle 等人(2024b),图 7)

在 y 轴上绘制了 7 个基准的平均准确率,即 arc challenge、copa、hellaswag、nq、piqa、siqa 和 tqa,与 x 轴上的训练步骤相对应。

一种可能的解释是,回答基于回忆的问题,如“法国的首都是什么?”需要精确关注“巴黎”这个词。通过迫使模型一次性预测多个标记,如在“巴黎是……的一个城市,”中,它可能会稀释来自最关键标记的整体信号,从而降低模型在整体基准上的性能。如果你的目标是构建一个 RAG(检索增强生成)系统或一个问答机器人,MTP 实际上可能是有害的。

2. n的“金发女孩”敏感性

这里没有“越多越好”的规则。作者发现,性能对头的数量(n)非常敏感。

作者还得出结论,头的数量(n)并不与 MTP 性能线性增长。存在一个“甜蜜点”,模型可以最有效地利用 MTP 范式:

  • 太少了(n=2):收益微乎其微,因为模型没有足够的激励去发展任何远见。

  • 太多了(n=8):性能迅速下降,因为所有 8 个头的信息开始过度拥挤共享树干的隐藏状态。

  • 正好合适(n=4):最佳性能

这引入了一个必须调整的新超参数。与 Next-Token Prediction 不同,后者只是“有效”,MTP 需要找到与你的数据复杂性相匹配的具体范围。

结论

通过其提高编码性能和推理速度的能力,一个明显的问题仍然存在:如果 MTP 如此革命性,为什么还没有任何主要的人工智能实验室使用它呢?

对此的答案是 DeepSeek-V3。

在他们的技术报告(Liu et al., 2024)⁴中,DeepSeek 团队揭示了 MTP 是模型训练的核心组件。类似于 Meta,他们在 15.7B 和 228.7B 参数规模上对标准 NTP 模型与 MTP 进行了彻底的消融研究。在训练过程中使用<em>n</em>=2配置(预测一个额外的未来标记),他们发现 MTP 训练的模型在所有数据集上,如 MMLU、PILE-test、HumanEval、MBPP 等,都一致优于其 NTP 对应模型。此外,通过在推理过程中保持那个第二个预测头进行如前所述的推测解码,DeepSeek 实现了高达 1.8 倍的推理速度提升。

DeepSeek 的成功部署为 MTP 作为大型语言模型训练目标的广泛应用提供了实际验证,因为它展示了通过最小相关缺点提高模型推理能力和推理效率的明确路径。

如果你喜欢这类分析,我在这里分享更多见解、笔记和解释:steadysurfdom.substack.com/

参考文献

[1] 帕尔,科耶纳,等. “未来镜头:从单个隐藏状态预测后续令牌.arXiv 预印本 arXiv:2311.04897 (2023).

[2] 格洛克勒,法比安,等. “通过多令牌预测构建更好、更快的语言模型.arXiv 预印本 arXiv:2404.19737 (2024).

[3] 诺德奎斯特,R. (2024, 7 月 20 日). [英语功能词的定义和示例. ThoughtCo.]( https://www.thoughtco.com/function-word-grammar-1690876#:~:text=According to social psychologist James ,is%20and%20how%20it%20works)

[4] 刘爱新,等. “Deepseek-v3 技术报告.arXiv 预印本 arXiv:2412.19437 (2024).

posted @ 2026-03-28 09:45  布客飞龙II  阅读(40)  评论(0)    收藏  举报