[PaperReading] PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

link
时间:23.04
单位:Meta
被引次数:363
主页:
https://docs.pytorch.org/docs/stable/fsdp.html
https://engineering.fb.com/2021/07/15/open-source/fsdp/

TL;DR

Pytorch官方提供的大模型训练训练推理接口,垂直方向将不同Layer划分为不同unit组,水平方向将weights分到不同device,Forward与Backward时All-gather回来进行计算。

Method

模型整体拆法如Fitgure1所示,FSDP的sharding主要指得是水平方向将weights切分到不同device存储,垂直方向仅仅是逻辑上的分组,方便sharding的分组管理。

水平拆分参考官方blog里面的介绍图。相比与传统Data Parallel,FSDP每个device仅持有\(1/N\) unit组模型的参数切片,在Forward与Backward之前,需要先All-gather拿到该unit组的完整模型,计算完成将peer devices的参数扔掉。

System Design

Model Initialization

​​核心问题​​

  • 挑战1​​:传统PyTorch要求模型参数在初始化时全量驻留GPU内存,无法直接初始化超大模型(如参数>单卡显存)。
  • ​​挑战2​​:用户自定义的初始化逻辑(如参数依赖其他模块的参数)需在分片后保持正确性。

​​解决方案:Deferred Initialization(延迟初始化)​​

  • 假设备(Fake Device)​​:在初始化阶段,参数存储在虚拟设备(如meta设备)上,仅记录形状和初始化操作(如nn.init.kaiming_normal_),不分配实际存储,不做任何初始化动作。优势​​:避免OOM,支持任意规模模型的初始化。
  • ​​按需分片初始化​​:将模型分解为多个FSDP单元(unit),逐个单元移动到真实GPU,执行记录的初始化操作后立即分片。
  • 处理依赖​​:若单元A依赖单元B的参数,FSDP临时恢复B的完整参数,初始化完成后再次分片。
    以下是 ​​PyTorch FSDP 论文第三章《SYSTEM DESIGN》各小节的详细解析​​,结合算法原理、设计动机和实现细节进行说明:

Sharding Strategies(分片策略)​​

Full Sharding(完全分片)​​

  • 机制​​:参数沿(num_params / world_size)均匀分片,每卡仅保留局部分片。计算前通过AllGather恢复完整参数,计算后立即释放非本地分片(ReduceScatter聚合梯度)。
  • ​​问题​​:直接分片单个参数导致通信效率低(小张量过多)。
  • 方案​​:将多个参数拼接为连续的FlatParameter,分片后每卡获得等大的连续块。填充对齐​​:确保FlatParameter总长度可被world_size整除,最小化填充开销。优势​​:减少通信次数,提升带宽利用率(见图2实验)。
    FlatParameter示意图

图2表明:a.切片参数越均匀,通信成本越小;b.通信总数据量相同情况下,数据块越大通信成本越低;

Hybrid Sharding(混合分片)​​

  • 应用场景​​:模型略大于单卡显存,但Full Sharding通信开销又太大。
  • 设计​​:将设备分为多个​​分片组​​(Sharding Group)和​​复制组​​(Replication Group)。
  • ​​分片组内​​:参数分片存储,组内AllGather/ReduceScatter。
  • 复制组间​​:组间AllReduce同步梯度。
    优势​​:
  • 减少跨节点通信(如分片组限制在同一节点内)。
  • 支持灵活调整分片因子(sharding_factor),平衡内存与吞吐。

Communication Optimizations(通信优化)​​

Overlapping Communication and Computation​​ (通信与计算时间重叠)

  • 问题​​:AllGather在通信过程会阻塞计算。
  • 方案​​:a.使用​​独立CUDA流​​发起通信,与计算流并行;b.同步机制​​:通过work.wait()确保计算依赖的通信完成;

Backward/Forward Prefetching

我们知道FSDP在Forward与Backward前有All-Gather操作会引发通信延迟,Prefetching机制使得该通信操作提前进行,使得通信与计算时间重叠。有空Backward Prefetching的对比参考Figure6 b实验。

Gradient Accumulation(梯度累积)​​

  • 带通信​​:多步累积后ReduceScatter,内存友好。
  • 无通信​​:保留完整梯度(更高内存),减少通信次数。

Memory Management(内存管理)​​

  • 问题:​​CUDA缓存分配器​​:频繁分配/释放导致内存碎片,尤其在多流场景下。FSDP痛点​​:AllGather流(生产者)与计算流(消费者)竞争内存块,可能触发阻塞式cudaMalloc重试。
  • 解决方案:Rate Limiter(速率限制器)​​
    • 限制​​最大在途AllGather数量​​(默认2),避免生产者过快耗尽内存。
    • 通过torch.cuda.memory_stats()监控num_alloc_retries,动态调整限制。
      ​​效果​​:减少碎片化,提升吞吐(见图6c实验对比)。

Code && Implementation

暂无

Experiment

Figure6 b实验说明在2.8B及以下模型尺寸实验上,使用FSDP与DDP的性能接近。

  • Figure7c看出当GPU达到512时,平均每张卡的吞吐量会下降7%,说明此时的通信与计算开销已经不能很好的重叠了。
  • Figure7b的单卡吞吐量明显较低,这是因为在128GPUs外加GPT175B的情况下,反传过程更容易触发CUDA内存整理。

Q&A

Q: FSDP的垂直方向分组与Pipeline Parallelism有什么区别?
A: FSDP中的水平Unit拆分(如将模型按层分组)本质上是为了更灵活地控制参数分片(Sharding)的粒度,而非将不同Unit分配到不同设备组上串行执行。​​ 这与流水线并行(Pipeline Parallelism)的物理阶段划分有根本区别。

总结与思考

  • FSDP仍然属于数据并行,因为计算的时候还是All-Gather原始的全部参数来计算的,而不是像tensor parallel那样只拿weight的分片计算。

相关链接

https://zhuanlan.zhihu.com/p/694288870
https://docs.pytorch.org/docs/stable/fsdp.html
https://engineering.fb.com/2021/07/15/open-source/fsdp/

本工作受Zero-DP启发,可以对比来看两个工作。

posted @ 2025-06-07 18:44  fariver  阅读(162)  评论(0)    收藏  举报