AdamW优化器

对于大模型训练最常用的 Adam / AdamW 优化器来说,优化器状态占用的显存通常是模型参数本身大小的 2 倍到 3 倍。这是训练显存爆炸的“头号元凶”。

以下是详细的拆解和计算:

1. 核心结论:优化器里到底存了什么?

在混合精度训练(FP16 训练)中,AdamW 优化器内部主要维护了以下三样东西,它们都是 FP32(单精度,4字节) 格式:

  1. 主权重副本(Master Weights)
    • 大小:与模型参数量一致。
    • 原因:模型在显存里是 FP16(2字节)格式,但为了更新参数时不损失精度,优化器会维护一份 FP32(4字节)的“完整版”权重副本。
  2. 一阶动量(Momentum, m*m* )
    • 大小:与模型参数量一致。
    • 原因:这是梯度的指数移动平均值(可以理解为“速度”),用于平滑更新方向。
  3. 二阶动量(Variance, v*v* )
    • 大小:与模型参数量一致。
    • 原因:这是梯度平方的指数移动平均值(可以理解为“加速度”或“方差”),用于自适应调整学习率。

2. 直观对比:不同优化器的显存账单

假设我们有一个 70亿(7B)参数 的模型,我们来看看不同优化器会让显存增加多少:

表格

优化器类型 内部存储内容 额外显存占用 总占用倍数 说明
SGD 0 1倍 仅模型权重,不存额外状态(最省显存,但难收敛)
SGD + Momentum 动量 1倍模型大小 2倍 多存一份动量
Adam / AdamW 主权重 + 动量 + 方差 3倍模型大小 4倍 最常用,但也最吃显存

注意:这里的“总占用倍数”是指相对于模型权重本身的大小。
例如:7B 模型权重本身占 14GB (FP16),用 AdamW 优化器后,仅优化器状态就要占 42GB (FP32)。

3. 算一笔账:7B 模型的显存去哪了?

让我们把模型参数、梯度和优化器放在一起,看看完整的显存账单(以 FP16 混合精度训练为例):

  • 模型参数:14 GB (FP16)
  • 梯度:14 GB (FP16)
  • 优化器状态42 GB (FP32)
    • 其中:主权重副本 28GB + 动量 14GB + 方差 14GB

结论
如果不做任何优化,光是静态的模型状态(参数+梯度+优化器)就需要 70 GB 显存。这还没算激活值(Activation,随 Batch Size 变化,可能高达几十 GB)。

这就是为什么单张 24GB 的显卡(如 4090)连 7B 模型的推理都费劲,更别说训练了。