训练时的显存优化

总览

HuggingFace 的这篇文章总结了一系列节约显存的方法,非常全面。

训练时显存占用的组成:

  • 模型参数
  • 优化器状态
  • 输入张量和其他临时张量
  • 激活值
  • 梯度
  • 通信缓冲

“激活值” 可能有点难理解。这是指像是 dropout 的 mask、LayerNorm 的 \(\mu\ \sigma^2\) 等,不是梯度但参加到梯度计算的张量。

除了用混合精度等方法降低整体显存占用,从 降低显存占用峰值 入手也是有效的。

融合 backward pass 和 optimizer step

通常的训练过程:计算 loss、反向传播、使用优化器 然后 清除梯度。

loss = loss_fn(model(inputs, targets))
loss.backward()
optimizer.step()
optimizer.zero_grad()

这就意味着,我们一次性计算了所有梯度,然后一并应用优化器参数更新。

如果能边算梯度边更新参数,就不需要用大量空间去存储梯度数据了。这就是融合 backward pass 和 optimizer step 的原理,能够有效降低显存占用峰值。

对于 PyTorch Lightning,需要借助 fsdp_overlap_step_with_backward 处理优化器逻辑:

from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward
optimizers = [Optimizer([p], ...) for p in model.parameters()]
...
for inputs, targets in epoch:
    loss = loss_fn(model(inputs), targets)
    with fsdp_overlap_step_with_backward(optimizers, model):
        loss.backward()

若需要原生 PyTorch 实现,可以借助 register_post_accumulate_grad_hook

optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

def optimizer_hook(parameter) -> None:
  optimizer_dict[parameter].step()
  optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
   p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  loss = model.forward(fake_image)
  loss.sum().backward()

本节参考:

优化器的选择

AdamW 优化器最为常用,调参简单效果好。要说缺点,就是每个参数都需要额外 8 字节的显存。

Adafactor 优化器改变 Adam 的动量思路,将空间占用降低到了 4 字节。但实际使用中发现 Adafactor 可能会导致训练不稳定。

Bitsandbytes 库提供了一系列 8-bit 优化器。其实现的 AdamW8bit 只需占用 2 字节空间。

这个 issue 是包含各种优化器的 benchmark。可以看出,各优化器的训练损失都差不多。这么说,大胆使用 AdamW8bit 节省显存是个不错的主意。

对于参数少、激活多的网络(例如卷积网络),8-bit 优化器的效果不是很明显。

Bitsandbytes 库推荐在使用 8-bit 优化器训练 NLP 模型时,将 embedding 层换为 bitsandbytes.nn.StableEmbedding 以保证训练稳定性。对于其他不稳定的参数,也可以使用 这个文档 提到的方法对那些参数单独使用 32-bit 优化器。

这个知乎问题下 提到 8-bit 优化器可能会让模型容易过拟合。注意一下。

PyTorch Lightning 对 Bitsandbytes 库有支持,可以自动替换用上 Bitsandbytes 的 8-bit 线性层。具体可看官方文档

关闭优化器的 foreach

PyTorch 的优化器默认启用了一个叫 foreach 的 trick,能加快训练。但随之而来的是额外的优化器中间变量占用,会导致峰值显存占用变高。若要关闭 foreach,在定义优化器时传入参数 foreach=False 即可。

本节参考:

posted @ 2024-04-18 16:42  倒地  阅读(50)  评论(0编辑  收藏  举报