torch.distributed 概述

Pytorch distributed 概述

本节我们介绍一下 torch.distributed

Pytorch 分布式库主要包含一套并行的模块,一个通信层,以及对于运行和debug大规模训练的infra

主要有以下四个并行的apis:

  • DDP(分布式数据并行)
  • FSDP (fully sharded data-parallel training)
  • Tensor parallel(tp)
  • pipeline parallel(pp)

分片原语:

DTensorDeviceMesh 是可以根据在N维的进程分组进行构建来开启并行。

  • DTensor: 表示一个 sharded and/or replicated 的tensor,可以根据操作自动地reshard tensor
  • DeviceMesh: 将 device communicator 抽象为 一个多维数组,可以管理底层的 ProcessGroup 实例 来在一个多维的并行上进行集合通信。

通信api:

pytorch分布式通信层(c10d)提供了集合通信api(例如 all_reduce, all_gather) 以及 P2P 的api (例如send和isend)

launcher

torchrun是一个通常使用的launch脚背,可以在本地和远程机器上spawns 进程来运行分布式的pytorch程序

应用并行来scale你的模型

数据并行:模型被复制到每个进程上

模型并行:模型被放进一个GPU内

  1. 如果你的模型能放入一个GPU,想使用多GPU进行scale,那就使用DDP.
    • 如果使用了多个节点,用torchrun来launch多个pytorch进程
  2. 如果不能放进GPU,那就使用 FSDP
  3. 如果到达了FSDP的scale极限,使用tp 及 pp
posted @ 2025-03-20 17:47  xwher  阅读(112)  评论(0)    收藏  举报