AdamW优化器
对于大模型训练最常用的 Adam / AdamW 优化器来说,优化器状态占用的显存通常是模型参数本身大小的 2 倍到 3 倍。这是训练显存爆炸的“头号元凶”。
以下是详细的拆解和计算:
1. 核心结论:优化器里到底存了什么?
在混合精度训练(FP16 训练)中,AdamW 优化器内部主要维护了以下三样东西,它们都是 FP32(单精度,4字节) 格式:
- 主权重副本(Master Weights):
- 大小:与模型参数量一致。
- 原因:模型在显存里是 FP16(2字节)格式,但为了更新参数时不损失精度,优化器会维护一份 FP32(4字节)的“完整版”权重副本。
- 一阶动量(Momentum, m*m* ):
- 大小:与模型参数量一致。
- 原因:这是梯度的指数移动平均值(可以理解为“速度”),用于平滑更新方向。
- 二阶动量(Variance, v*v* ):
- 大小:与模型参数量一致。
- 原因:这是梯度平方的指数移动平均值(可以理解为“加速度”或“方差”),用于自适应调整学习率。
2. 直观对比:不同优化器的显存账单
假设我们有一个 70亿(7B)参数 的模型,我们来看看不同优化器会让显存增加多少:
表格
| 优化器类型 | 内部存储内容 | 额外显存占用 | 总占用倍数 | 说明 |
|---|---|---|---|---|
| SGD | 无 | 0 | 1倍 | 仅模型权重,不存额外状态(最省显存,但难收敛) |
| SGD + Momentum | 动量 | 1倍模型大小 | 2倍 | 多存一份动量 |
| Adam / AdamW | 主权重 + 动量 + 方差 | 3倍模型大小 | 4倍 | 最常用,但也最吃显存 |
注意:这里的“总占用倍数”是指相对于模型权重本身的大小。
例如:7B 模型权重本身占 14GB (FP16),用 AdamW 优化器后,仅优化器状态就要占 42GB (FP32)。
3. 算一笔账:7B 模型的显存去哪了?
让我们把模型参数、梯度和优化器放在一起,看看完整的显存账单(以 FP16 混合精度训练为例):
- 模型参数:14 GB (FP16)
- 梯度:14 GB (FP16)
- 优化器状态:42 GB (FP32)
- 其中:主权重副本 28GB + 动量 14GB + 方差 14GB
结论:
如果不做任何优化,光是静态的模型状态(参数+梯度+优化器)就需要 70 GB 显存。这还没算激活值(Activation,随 Batch Size 变化,可能高达几十 GB)。
这就是为什么单张 24GB 的显卡(如 4090)连 7B 模型的推理都费劲,更别说训练了。
浙公网安备 33010602011771号