为什么必须走上“分布式训练”之路?

模型“体积”在爆炸

从 CV、NLP 到语音、推荐,深度学习已无处不在。ChatGPT 的出圈正式拉开“大模型”帷幕:

  • GPT-3 1 750 亿参数,用 1 024 张 80 GB A100 完整训练仍需 ≈1 个月;
  • 训练算力需求每 2 年≈15× 增长,Transformer 类更达 750×;
  • 模型大小平均每 2 年≈240× 膨胀。

分布式训练 = 唯一可扩展路线

只有把“成百上千”张 GPU/TPU/ASIC 组织成集群,才能同时突破算力、内存两重天花板:

  • 数据并行:批量拆分到多卡,同步梯度,线性提升吞吐;
  • 模型并行 / 流水线并行:把网络层或参数矩阵切分到不同设备,解决“放不下”问题;
  • 多维混合并行:在节点内模型并行、节点间数据并行,再叠加 ZeRO/张量并行,实现“能放又能跑”。

数据并行

核心思想:在各个GPU上都拷贝一份完整模型,各自吃一份数据,算一份梯度,最后对梯度进行累加来更新整体模型。理念不复杂,但到了大模型场景,巨大的存储和GPU间的通讯量,就是系统设计要考虑的重点了。在本文中,我们将递进介绍三种主流数据并行的实现方式:

  • DP(Data Parallelism):最早的数据并行模式,一般采用参数服务器(Parameters Server)这一编程框架。实际中多用于单机多卡。
  • DDP(Distributed Data Parallelism):分布式数据并行,采用Ring AllReduce的通讯方式,实际中多用于多机场景。
  • ZeRO:零冗余优化器。由微软推出并应用于其DeepSpeed框架中。严格来讲ZeRO采用数据并行+张量并行的方式,旨在降低存储。

DP(Data Parallel)

DP流程

  1. 若干块计算GPU,如图中GPU0~GPU2;1块梯度收集GPU,如图中AllReduce操作所在GPU。
  2. 在每块计算GPU上都拷贝一份完整的模型参数。
  3. 把一份数据X(例如一个batch)均匀分给不同的计算GPU。
  4. 每块计算GPU做一轮FWD和BWD后,算得一份梯度G。
  5. 每块计算GPU将自己的梯度push给梯度收集GPU,做聚合操作。这里的聚合操作一般指梯度累加。当然也支持用户自定义。
  6. 梯度收集GPU聚合完毕后,计算GPU从它那pull下完整的梯度结果,用于更新模型参数W。更新完毕后,计算GPU上的模型参数依然保持一致。

聚合再下发梯度的操作,称为AllReduce(集体通信)。

在这里插入图片描述

通讯瓶颈

DP的框架理解起来不难,但实战中确有两个主要问题:

  • 存储开销大。每块GPU上都存了一份完整的模型,造成冗余。关于这一点的优化,我们将在后文ZeRO部分做讲解。
  • 通讯开销大。Server需要和每一个Worker进行梯度传输。当Server和Worker不在一台机器上时,Server的带宽将会成为整个系统的计算效率瓶颈。

我们对通讯开销再做详细说明。如果将传输比作一条马路,带宽就是马路的宽度,它决定每次并排行驶的数据量。例如带宽是100G/s,但每秒却推给Server 1000G的数据,消化肯定需要时间。那么当Server在搬运数据,计算梯度的时候,Worker们就会摸鱼。

DDP(Distributed Data Parallel)

受通信负载不均的制约,传统的 Data Parallel(DP)通常只用于单机多卡场景。为此,Distributed Data Parallel(DDP)作为更通用的解决方案应运而生,既支持单机,也支持多机训练。

DDP 首先要解决的是通信瓶颈:把原本集中在 Server 上的通信压力均衡地分散到各个 Worker 节点。实现这一目标后,系统甚至可以完全去掉 Server,只保留 Worker。

前文提到,"聚合梯度再下发梯度"这一完整操作称为 AllReduce。接下来介绍目前最通用的 AllReduce 实现——Ring-AllReduce。该算法由百度率先提出,有效解决了数据并行中的通信负载不均问题,为 DDP 的多机扩展奠定了基础。

Ring-AllReduce

不同于 Parameter Server 模式的一种AllReduce方法。如下图,假设有4块GPU,每块GPU上的数据也对应被切成4份。AllReduce 的最终目标,就是让每块GPU上的数据都变成箭头右边汇总的样子。
在这里插入图片描述

Ring-ALLReduce则分两大步骤实现该目标:Reduce-ScatterAll-Gather

Reduce-Scatter

定义网络拓扑关系,使得每个GPU只和其相邻的两块GPU通讯。每次发送对应位置的数据进行累加。每一次累加更新都形成一个拓扑环,因此被称为Ring

初始状态:
GPU0: [A0, B0, C0, D0]
GPU1: [A1, B1, C1, D1]
GPU2: [A2, B2, C2, D2]
GPU3: [A3, B3, C3, D3]

在这里插入图片描述
在这里插入图片描述

一次累加完毕后,蓝色位置的数据块被更新,被更新的数据块将成为下一次更新的起点,继续做累加操作。

第一次迭代后:
GPU0: [A0, B0, C0, D0+A3]
GPU1: [A1, B1, C1, D1+D0]
GPU2: [A2, B2, C2, D2+D1]
GPU3: [A3, B3, C3, D3+D2]

在这里插入图片描述
在这里插入图片描述

3次更新之后,每块GPU上都有一块数据拥有了对应位置完整的聚合(图中红色)。此时,Reduce-Scatter阶段结束。进入All-Gather阶段,目标是把红色块的数据广播到其余GPU对应的位置上。

All-Gather

如名字里Gather所述的一样,这操作里依然按照"相邻GPU对应位置进行通讯"的原则,但对应位置数据不再做相加,而是直接替换。All-Gather以红色块作为起点。

第一次All-Gather:
GPU0: [A0, B0, C0, D0+D3+D2+D1]
GPU1: [A1, B1, C1, D1+D0+A3+A2]
GPU2: [A2, B2, C2, D2+D1+D0+A3]
GPU3: [A3, B3, C3, D3+D2+D1+D0]

在这里插入图片描述
在这里插入图片描述

以此类推,根据Reduce-Scatter的结果进行三次迭代,即可完成。
在这里插入图片描述

总结

  1. 在DP中,每个GPU上都拷贝一份完整的模型,每个GPU上处理batch的一部分数据,所有GPU算出来的梯度进行累加后,再传回各GPU用于更新参数。
  2. DP多采用参数服务器这一编程框架,一般由若个计算Worker和1个梯度聚合Server组成。Server与每个Worker通讯,Worker间并不通讯。因此Server承担了系统所有的通讯压力。基于此DP常用于单机多卡场景。
  3. Ring-AllReduce通过定义网络环拓扑的方式,将通讯压力均衡地分到每个GPU上,使得跨机器的数据并行(DDP)得以高效实现。
  4. DP和DDP的总通讯量相同,但因负载不均的原因,DP需要耗费更多的时间搬运数据。