详细介绍:PyTorch API 3 - distributed
文章目录
分布式通信包 - torch.distributed
注意:关于分布式训练相关功能的简要介绍,请参阅PyTorch分布式概述。
后端支持
torch.distributed 支持三种内置后端,每种后端具有不同的功能特性。下表展示了哪些功能可用于 CPU/CUDA 张量。
 注意:MPI 仅在用于构建 PyTorch 的实现支持 CUDA 时,才能启用 CUDA 功能。
| 后端 | gloo | mpi | nccl | |||
|---|---|---|---|---|---|---|
| 设备类型 | CPU | GPU | CPU | GPU | CPU | GPU | 
| 发送 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 接收 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 广播 | ✓ | ✓ | ✓ | ? | ✘ | ✓ | 
| 全归约 | ✓ | ✓ | ✓ | ? | ✘ | ✓ | 
| 归约 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 全收集 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 收集 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 分散 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
| 归约分散 | ✘ | ✘ | ✘ | ✘ | ✘ | ✓ | 
| 全到全 | ✘ | ✘ | ✓ | ? | ✘ | ✓ | 
| 屏障 | ✓ | ✘ | ✓ | ? | ✘ | ✓ | 
PyTorch 内置的后端
PyTorch 分布式包支持 Linux(稳定版)、MacOS(稳定版)和 Windows(原型版)。在 Linux 平台上,默认会构建并包含 Gloo 和 NCCL 后端(NCCL 仅在 CUDA 环境下构建时包含)。MPI 是一个可选后端,只有从源码构建 PyTorch 时才能包含(例如在已安装 MPI 的主机上构建 PyTorch)。
注意:从 PyTorch v1.8 开始,Windows 支持除 NCCL 之外的所有集体通信后端。如果 init_process_group() 的 init_method 参数指向文件,则必须遵循以下格式:
- 本地文件系统:init_method="file:///d:/tmp/some_file"
- 共享文件系统:init_method="file://////{machine_name}/{share_folder_name}/some_file"
与 Linux 平台相同,您可以通过设置环境变量 MASTER_ADDR 和 MASTER_PORT 来启用 TcpStore。
选择哪个后端?
过去我们经常被问到:“我应该使用哪个后端?”
- 经验法则 
 - 分布式 GPU 训练使用 NCCL 后端
- 分布式 CPU 训练使用 Gloo 后端
 
- 配备 InfiniBand 互连的 GPU 主机 
 - 使用 NCCL,因为它是目前唯一支持 InfiniBand 和 GPUDirect 的后端
 
- 配备以太网互连的 GPU 主机 
 - 使用 NCCL,因为它目前能提供最佳的分布式 GPU 训练性能,尤其适用于多进程单节点或多节点分布式训练。如果遇到 NCCL 相关问题,可将 Gloo 作为备选方案。(注意:当前 Gloo 在 GPU 上的运行速度慢于 NCCL)
 
- 配备 InfiniBand 互连的 CPU 主机 
 - 若 InfiniBand 已启用 IP over IB 功能则使用 Gloo,否则改用 MPI。我们计划在后续版本中为 Gloo 添加 InfiniBand 支持
 
- 配备以太网互连的 CPU 主机 
 - 除非有特殊需求需使用 MPI,否则默认选择 Gloo
 
常见环境变量
选择使用的网络接口
默认情况下,NCCL 和 Gloo 后端都会尝试自动选择合适的网络接口。如果自动检测的接口不正确,可以通过以下环境变量手动指定(分别对应各自的后端):
- NCCL_SOCKET_IFNAME,例如- export NCCL_SOCKET_IFNAME=eth0
- GLOO_SOCKET_IFNAME,例如- export GLOO_SOCKET_IFNAME=eth0
如果使用 Gloo 后端,可以通过逗号分隔指定多个接口,例如:export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3。后端会以轮询方式在这些接口间分配操作。必须确保所有进程在该变量中指定相同数量的接口。
其他NCCL环境变量
调试功能 - 当NCCL出现故障时,可设置NCCL_DEBUG=INFO来打印明确的警告信息以及基础的NCCL初始化信息。
您还可以使用NCCL_DEBUG_SUBSYS获取NCCL特定模块的详细日志。例如,设置NCCL_DEBUG_SUBSYS=COLL将打印集合通信调用的日志,这对调试卡死问题(特别是由集合操作类型或消息大小不匹配引发的问题)很有帮助。若遇拓扑结构检测失败的情况,设置NCCL_DEBUG_SUBSYS=GRAPH可查看详细检测结果,如需NCCL团队进一步协助,该日志可作为参考依据保存。
性能调优 - NCCL基于拓扑检测结果进行自动调优以减少用户工作量。在某些基于socket的系统中,用户仍可尝试调整NCCL_SOCKET_NTHREADS和NCCL_NSOCKS_PERTHREAD来提升socket网络带宽。这两个环境变量已在AWS、GCP等云服务商环境中经过NCCL预调优。
完整NCCL环境变量列表请参阅NVIDIA NCCL官方文档
基础概念
torch.distributed 包为 PyTorch 提供了跨多个计算节点(运行在一台或多台机器上)的多进程并行支持及通信原语。torch.nn.parallel.DistributedDataParallel() 类基于此功能,通过封装任意 PyTorch 模型来提供同步分布式训练。这与 Multiprocessing package - torch.multiprocessing 和 torch.nn.DataParallel() 提供的并行方式不同,因为它支持多台网络连接的机器,并且需要用户显式地为每个进程启动主训练脚本的独立副本。
在单机同步场景下,torch.distributed 或 torch.nn.parallel.DistributedDataParallel() 封装器相比其他数据并行方法(包括 torch.nn.DataParallel())仍具有优势:
- 独立优化器:每个进程维护自己的优化器,并在每次迭代中执行完整的优化步骤。虽然这看似冗余(因为梯度已在进程间收集并平均,各进程梯度相同),但省去了参数广播步骤,从而减少了节点间张量传输的时间开销。
- 独立 Python 解释器:每个进程拥有独立的 Python 解释器,避免了单 Python 进程中驱动多个执行线程、模型副本或 GPU 时产生的额外解释器开销和 “GIL 争用”。这对于重度依赖 Python 运行时的模型(如包含循环层或大量小组件的模型)尤为重要。
初始化
在使用其他方法之前,需要通过 torch.distributed.init_process_group() 或 torch.distributed.device_mesh.init_device_mesh() 函数初始化该包。这两个函数都会阻塞,直到所有进程都加入为止。
警告:初始化操作不是线程安全的。进程组的创建应在单一线程中执行,以防止不同进程间出现不一致的 ‘UUID’ 分配,并避免初始化期间的竞争条件导致程序挂起。
torch.distributed.is_available()如果分布式包可用则返回 True。
否则,torch.distributed 不会暴露任何其他 API。目前 torch.distributed 在 Linux、MacOS 和 Windows 平台上可用。若要从源码构建 PyTorch 时启用该功能,需设置:
USE_DISTRIBUTED=1
当前默认值为:Linux 和 Windows 系统下 USE_DISTRIBUTED=1,MacOS 系统下 USE_DISTRIBUTED=0。
返回类型:bool
torch.distributed.init_process_group(backend=None, init_method=None, timeout=None, world_size=-1, rank=-1, store=None, group_name='', pg_options=None, device_id=None)初始化默认的分布式进程组。
这将同时初始化分布式包。
初始化进程组主要有两种方式:
1、显式指定 store、rank 和 world_size
2、指定 init_method(URL字符串)来指示如何发现对等节点。可选指定 rank 和 world_size,或将所有必需参数编码在URL中并省略它们
如果均未指定,则默认 init_method 为 “env://”。
参数说明
- backend (str 或 Backend, 可选)- 使用的后端。根据构建配置,有效值包括 mpi、gloo、nccl、ucc 或第三方插件注册的后端。从 2.6 版本开始,若未提供 backend,c10d 将根据 device_id 参数(如提供)对应的设备类型使用注册的后端。当前已知的默认注册为:cuda 设备使用 nccl,cpu 设备使用 gloo。若 backend 和 device_id 均未提供,c10d 将自动检测运行机器的加速器并使用对应注册的后端(或 cpu)。该字段可接受小写字符串(如 “gloo”),也可通过 Backend 属性访问(如 Backend.GLOO)。注意:使用 nccl 后端时,若单机多进程,每个进程必须独占其使用的 GPU,进程间共享 GPU 可能导致死锁或 NCCL 非法使用。ucc 后端为实验性功能。
- init_method (str, 可选)- 指定进程组初始化方式的 URL。若未指定 init_method 或 store,默认为 “env://”。与 store 参数互斥。
- world_size (int, 可选)- 参与任务的进程总数。若指定 store 则必须提供。
- rank (int, 可选)- 当前进程的排名(取值范围应为 0 到 world_size-1)。若指定 store 则必须提供。
- store (Store, 可选)- 所有工作进程可访问的键值存储,用于交换连接/地址信息。与 init_method 互斥。
- timeout (timedelta, 可选)- 进程组操作的超时时间。NCCL 默认 10 分钟,其他后端默认 30 分钟。超时后异步中止集合操作并终止进程。由于 CUDA 执行是异步的,继续执行用户代码可能不安全,因为失败的异步 NCCL 操作可能导致后续 CUDA 操作处理损坏数据。当设置 TORCH_NCCL_BLOCKING_WAIT 时,进程将阻塞等待此超时。
- group_name (str, 可选, 已弃用)- 组名(该参数已被忽略)
- pg_options (ProcessGroupOptions, 可选)- 进程组选项,用于在构建特定进程组时传递额外参数。目前仅支持 nccl 后端的 ProcessGroupNCCL.Options,可指定 is_high_priority_stream 让 nccl 后端在有计算内核等待时选择高优先级 CUDA 流。其他可配置 NCCL 的选项参见:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
- device_id (torch.device, 可选)- 绑定进程的特定设备,支持后端特定优化。目前仅在 NCCL 下有两个效果:立即形成通信器(直接调用 ncclCommInit* 而非延迟调用),子组尽可能使用 ncclCommSplit 以避免不必要的组创建开销。如需提前获知 NCCL 初始化错误,也可使用此字段。
注意事项
启用 backend == Backend.MPI 需在支持 MPI 的系统上从源码编译 PyTorch。
实验性说明
多后端支持目前处于实验阶段。未指定 backend 时,将同时创建 gloo 和 nccl 后端:CPU 张量的集合操作使用 gloo,CUDA 张量的集合操作使用 nccl。可通过格式为 “<设备类型>:<后端名称>,<设备类型>:<后端名称>” 的字符串指定自定义后端,例如:“cpu:gloo,cuda:custom_backend”。
torch.distributed.device_mesh.init_device_mesh(device_type, mesh_shape, *, mesh_dim_names=None)根据device_type、mesh_shape和mesh_dim_names参数初始化一个DeviceMesh。
这会创建一个具有n维数组布局的DeviceMesh,其中n是mesh_shape的长度。
如果提供了mesh_dim_names,则每个维度会被标记为mesh_dim_names[i]。
注意:init_device_mesh遵循SPMD编程模型,意味着相同的PyTorch Python程序会在集群中的所有进程/rank上运行。请确保mesh_shape(描述设备布局的n维数组的维度)在所有rank上保持一致。不一致的mesh_shape可能导致程序挂起。
注意:如果找不到进程组,init_device_mesh会在后台初始化分布式通信所需的分布式进程组/组。
参数
- device_type (str)- 网格的设备类型。当前支持:“cpu”、“cuda/cuda-like”。不允许传入带有GPU索引的设备类型,如"cuda:0"。
- mesh_shape (Tuple[int])- 定义描述设备布局的多维数组维度的元组。
- mesh_dim_names (Tuple[str], 可选)- 分配给描述设备布局的多维数组每个维度的网格维度名称元组。其长度必须与- mesh_shape的长度匹配。- mesh_dim_names中的每个字符串必须是唯一的。
返回
一个表示设备布局的DeviceMesh对象。
返回类型
示例:
>>
>
from torch.distributed.device_mesh import init_device_mesh
>>
>
>
>>
> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
>>
> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))torch.distributed.is_initialized()检查默认进程组是否已初始化。
返回类型:bool
torch.distributed.is_mpi_available()检查 MPI 后端是否可用。
返回类型:bool
torch.distributed.is_nccl_available()检查NCCL后端是否可用。
返回类型:bool
torch.distributed.is_gloo_available()检查 Gloo 后端是否可用。
返回类型:bool
torch.distributed.distributed_c10d.is_xccl_available()检查XCCL后端是否可用。
返回类型:bool
torch.distributed.is_torchelastic_launched()检查当前进程是否通过 torch.distributed.elastic(即 torchelastic)启动。
通过检测环境变量 TORCHELASTIC_RUN_ID 是否存在作为判断依据。这是一个合理的代理指标,因为 TORCHELASTIC_RUN_ID 映射到 rendezvous id(该值始终为非空,用于标识作业ID以实现节点发现)。
返回类型:bool
目前支持三种初始化方法:
TCP初始化
有两种使用TCP进行初始化的方式,两者都需要一个所有进程均可访问的网络地址和指定的world_size。第一种方式要求指定一个属于rank 0进程的地址。这种初始化方法要求所有进程都手动指定rank。
请注意,最新版本的分布式包不再支持多播地址。group_name参数也已弃用。
import torch.distributed as dist
# Use address of one of the machines
dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)共享文件系统初始化
另一种初始化方法利用了组内所有机器均可访问的共享文件系统,并配合指定的world_size参数。URL应以file://开头,并指向共享文件系统中某个不存在文件(位于已存在的目录)的路径。文件系统初始化会自动创建该文件(若不存在),但不会删除文件。因此,您需要确保在下一次对相同文件路径/名称调用init_process_group()前清理该文件。
请注意,最新版分布式包已不再支持自动分配rank,同时group_name参数也已弃用。
警告:此方法假设文件系统支持通过fcntl进行锁定——大多数本地系统和NFS都支持此功能。
警告:此方法总会创建文件,并会在程序结束时尽力清理和删除文件。换句话说,每次使用文件初始化方法时都需要一个全新的空文件才能成功初始化。如果重复使用前次初始化未清理的同一文件,将导致意外行为,通常会造成死锁和故障。因此,尽管该方法会尽力清理文件,但如果自动删除失败,您必须确保在训练结束后删除该文件,以防下次重复使用同一文件。当您计划对同一文件名多次调用init_process_group()时,这一点尤为重要。
简而言之,如果文件未被移除/清理,而您再次对该文件调用init_process_group(),预期会发生故障。经验法则是:确保每次调用init_process_group()时,目标文件不存在或是空文件。
import torch.distributed as dist
# rank should always be specified
dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile', world_size=4, rank=args.rank)环境变量初始化方法
该方法会从环境变量中读取配置,允许用户完全自定义信息的获取方式。需要设置的环境变量包括:
- MASTER_PORT- 必填;必须是 rank 0 机器上的空闲端口
- MASTER_ADDR- 必填(rank 0 除外);rank 0 节点的地址
- WORLD_SIZE- 必填;可以在此处设置,也可以在初始化函数调用时设置
- RANK- 必填;可以在此处设置,也可以在初始化函数调用时设置
rank 为 0 的机器将用于建立所有连接。
这是默认的初始化方法,意味着无需指定 init_method(或可设为 env://)。
初始化后操作
运行 torch.distributed.init_process_group() 后,即可使用以下函数。要检查进程组是否已完成初始化,请调用 torch.distributed.is_initialized()。
class torch
.distributed.Backend(name)一个类似枚举的后端类。
可用后端类型:GLOO、NCCL、UCC、MPI、XCCL 以及其他已注册的后端。
该类的值为小写字符串,例如 "gloo"。可以通过属性访问,例如 Backend.NCCL。
此类可直接调用来解析字符串,例如 Backend(backend_str) 会检查 backend_str 是否有效,若有效则返回解析后的小写字符串。它也接受大写字符串,例如 Backend("GLOO") 会返回 "gloo"。
注意:条目 Backend.UNDEFINED 存在但仅用作某些字段的初始值。用户既不应直接使用它,也不应假定其存在。
CLASSMETHOD register_backend(name, func, extended_api=False, devices=None)使用给定的名称和实例化函数注册一个新的后端。
这个类方法被第三方 ProcessGroup 扩展用于注册新的后端。
参数
- name (str)–- ProcessGroup扩展的后端名称。它应该与- init_process_group()中的名称匹配。
- func (function)– 实例化后端的函数处理程序。该函数应在后端扩展中实现,并接受四个参数,包括- store、- rank、- world_size和- timeout。
- extended_api ([bool], 可选)– 后端是否支持扩展参数结构。默认值:- False。如果设置为- True,后端将获得一个- c10d::DistributedBackendOptions实例,以及一个由后端实现定义的进程组选项对象。
- device (str 或 str 列表, 可选)– 该后端支持的设备类型,例如 “cpu”、“cuda” 等。如果为 None,则假定同时支持 “cpu” 和 “cuda”。
注意:对第三方后端的支持目前处于实验阶段,可能会发生变化。
torch.distributed.get_backend(group=None)返回给定进程组的后端。
参数
- group (ProcessGroup, 可选)– 要操作的进程组。默认为通用的主进程组。如果指定了其他特定组,调用进程必须是该- group的成员。
返回值:以小写字符串形式返回给定进程组的后端。
返回类型:Backend
torch.distributed.get_rank(group=None)返回当前进程在指定group中的排名,若无指定则返回默认值。
排名是分布式进程组中分配给每个进程的唯一标识符。这些排名始终是从0到world_size的连续整数。
参数
- group (ProcessGroup, 可选)– 要操作的进程组。如果为None,则使用默认进程组。
返回值:进程组的排名
- 如果不在该组中,则返回-1
返回类型:int
torch.distributed.get_world_size(group=None)返回当前进程组中的进程数量。
参数
- group (ProcessGroup, 可选)– 要操作的进程组。如果为None,则使用默认进程组。
返回值:进程组的全局大小
如果不在该组中,则返回-1
返回类型:int
关闭处理
在程序退出时,通过调用destroy_process_group()来清理资源非常重要。
推荐遵循的最简单模式是:在训练脚本中不再需要通信的地方(通常是在main()函数末尾附近),通过调用destroy_process_group()并保持group参数为默认值None,来销毁所有进程组和后端。每个训练器进程应该调用一次,而不是在外部的进程启动器层面调用。
如果在超时时间内,某个进程组(pg)中的所有rank都没有调用destroy_process_group(),特别是当应用中存在多个进程组时(例如用于N维并行的情况),可能会导致程序退出时挂起。这是因为ProcessGroupNCCL的析构函数会调用ncclCommAbort,而这个调用必须是集体操作,但如果由Python的垃圾回收器触发ProcessGroupNCCL析构函数的调用顺序是不确定的。显式调用destroy_process_group()可以确保所有rank以一致的顺序调用ncclCommAbort,并避免在ProcessGroupNCCL析构期间调用ncclCommAbort。
重新初始化
destroy_process_group 也可用于销毁单个进程组。一个典型应用场景是容错训练,其中进程组可能在运行时被销毁后重新初始化。这种情况下,关键是在调用销毁操作之后、重新初始化之前,通过非torch.distributed原语的其他方式同步训练器进程。由于实现此类同步的复杂性,该行为目前处于未支持/未测试状态,属于已知问题。若此场景对您造成阻碍,请提交GitHub issue或RFC。
组
默认情况下,集合操作作用于默认组(也称为全局组),并要求所有进程都参与分布式函数调用。然而,某些工作负载可能受益于更细粒度的通信。这正是分布式组发挥作用的地方。new_group() 函数可用于创建包含任意进程子集的新组。该函数返回一个不透明的组句柄,可作为 group 参数传递给所有集合操作(集合操作是指那些以特定编程模式交换信息的分布式函数)。
torch.distributed.new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None, device_id=None)创建一个新的分布式进程组。
该函数要求主进程组中的所有进程(即参与分布式作业的所有进程)都必须进入此函数,即使它们不会成为该组的成员。此外,所有进程必须以相同的顺序创建进程组。
警告:安全并发使用规范:
当使用NCCL后端的多进程组时,用户必须确保所有进程间集合操作的执行顺序全局一致。
如果单个进程内的多个线程发起集合操作,需要通过显式同步来确保执行顺序的一致性。
使用torch.distributed异步通信API时,会返回一个工作对象,通信内核会被放入独立的CUDA流中,从而实现通信与计算的重叠。当一个进程组发起一个或多个异步操作后,必须通过调用work.wait()与其他CUDA流同步,才能使用另一个进程组。
参数说明
- ranks (list[int])- 组成员rank列表。若为- None则包含所有rank,默认为- None
- timeout (timedelta, 可选)- 超时设置,详见- init_process_group说明
- backend (str 或 [Backend](https://pytorch.org/docs/stable/data.html#torch.distributed.Backend "torch.distributed.Backend"), 可选)- 使用的后端。根据构建配置可选- gloo或- nccl,默认使用全局组的后端。应传入小写字符串(如- "gloo"),也可通过- Backend属性指定(如- Backend.GLOO)。传入- None时将使用默认进程组的后端
- pg_options (ProcessGroupOptions, 可选)- 进程组配置选项,用于指定特殊参数。例如对- nccl后端可设置- is_high_priority_stream来启用高优先级CUDA流。其他NCCL配置选项参见类型文档
- use_local_synchronization ([bool], 可选)- 在进程组创建结束时执行组内局部屏障。与非成员rank不同,这些rank无需调用API且不参与屏障
- group_desc (str, 可选)- 进程组的描述字符串
- device_id (torch.device, 可选)- 要绑定的特定设备。若指定此参数,- new_group会立即尝试初始化该设备的通信后端
返回值
返回分布式组的句柄,可用于集合调用。若当前rank不在ranks中则返回GroupMember.NON_GROUP_MEMBER。
注意事项
1、use_local_synchronization不兼容MPI后端
2、在大型集群和小型进程组中使用use_local_synchronization=True可能显著提升性能,但需注意这会改变集群行为(非成员rank不参与屏障)
3、当各rank创建多个重叠进程组时,use_local_synchronization=True可能导致死锁。为避免此问题,需确保所有rank遵循相同的全局创建顺序
torch.distributed.get_group_rank(group, global_rank)将全局排名转换为组内排名。
如果 global_rank 不属于 group 的成员,此操作会抛出 RuntimeError。
参数
- group (ProcessGroup)– 用于查找相对排名的进程组。
- global_rank (int)– 要查询的全局排名。
返回值
返回 global_rank 相对于 group 的组内排名
返回类型
int
注意:在默认进程组上调用此函数会返回原值
torch.distributed.get_global_rank(group, group_rank)将组内排名转换为全局排名。
如果 group_rank 不属于该组,将抛出 RuntimeError。
参数
- group (ProcessGroup)– 用于查询全局排名的进程组。
- group_rank ( int )– 需要查询的组内排名。
返回值:group_rank 相对于 group 的全局排名
返回类型:int
注意:在默认进程组上调用此函数将返回原值
torch.distributed.get_process_group_ranks(group)获取与group关联的所有排名。
参数
- group (ProcessGroup)– 要从中获取所有排名的ProcessGroup。
返回值:按组内排名排序的全局排名列表。
返回类型:list [int]
DeviceMesh
DeviceMesh 是一种更高层次的抽象,用于管理进程组(或 NCCL 通信器)。它允许用户轻松创建节点间和节点内的进程组,而无需关心如何为不同的子进程组正确设置 ranks,并帮助轻松管理这些分布式进程组。可以通过 init_device_mesh() 函数创建新的 DeviceMesh,其中 mesh shape 参数用于描述设备拓扑结构。
class torch
.distributed.device_mesh.DeviceMesh(device_type, mesh, *, mesh_dim_names=None, _init_backend=True)DeviceMesh 表示一个设备网格,其中设备的布局可以表示为一个 n 维数组,该 n 维数组的每个值是默认进程组 ranks 的全局 ID。
DeviceMesh 可用于描述集群中设备的布局,并作为集群内设备列表间通信的代理。
DeviceMesh 可用作上下文管理器。
注意:DeviceMesh 遵循 SPMD 编程模型,这意味着相同的 PyTorch Python 程序会在集群中的所有进程/ranks 上运行。因此,用户需要确保描述设备布局的网格数组在所有 ranks 上保持一致。不一致的网格会导致静默挂起。
参数
- device_type (str)– 网格的设备类型。当前支持:“cpu”、“cuda/cuda-like”。
- mesh (ndarray)– 描述设备布局的多维数组或整数张量,其中 ID 是默认进程组的全局 ID。
返回
一个表示设备布局的 DeviceMesh 对象。
返回类型:DeviceMesh
以下程序以 SPMD 方式在每个进程/rank 上运行。在此示例中,我们有 2 台主机,每台主机有 4 个 GPU。
在网格的第一个维度上进行归约操作会跨列 (0, 4), … 和 (3, 7) 进行,在网格的第二个维度上进行归约操作会跨行 (0, 1, 2, 3) 和 (4, 5, 6, 7) 进行。
示例:
>>
>
from torch.distributed.device_mesh import DeviceMesh
>>
>
>
>>
>
# Initialize device mesh as (2, 4) to represent the topology
>>
>
# of cross-host(dim 0), and within-host (dim 1).
>>
> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])static from_group(group, device_type, mesh=None, *, mesh_dim_names=None)基于现有的ProcessGroup或一组ProcessGroup列表,构造指定device_type的DeviceMesh。
构造的设备网格维度数与传入的进程组数量相同。例如:
- 传入单个进程组时,生成1D网格
- 传入2个进程组列表时,生成2D网格
当传入多个进程组时,必须提供mesh和mesh_dim_names参数。进程组的传入顺序决定网格拓扑结构,例如第一个进程组对应DeviceMesh的第0维度。
传入的mesh张量必须满足:
1、维度数与进程组数量相同
2、张量维度顺序与进程组传入顺序一致
参数说明
- group (ProcessGroup* 或 list[ProcessGroup])- 现有进程组或进程组列表
- device_type (str)- 网格设备类型,当前支持:“cpu”、“cuda/cuda-like”。禁止传入带GPU索引的类型(如"cuda:0")
- mesh (torch.Tensor 或 *ArrayLike, 可选)- 描述设备布局的多维数组/整型张量,ID为默认进程组的全局ID。默认为None
- mesh_dim_names (tuple[str], 可选)- 为设备布局数组各维度命名的元组,其长度必须与mesh_shape匹配,且每个字符串必须唯一。默认为None
返回值:表示设备布局的DeviceMesh对象
返回类型:DeviceMesh
get_all_groups()返回所有网格维度的进程组列表。
返回值:一个包含 ProcessGroup 对象的列表。
返回类型:list [torch.distributed.distributed_c10d.ProcessGroup]
get_coordinate()返回当前秩相对于网格所有维度的相对索引。如果该秩不属于网格,则返回 None。
返回类型:Optional[list [int ]]
get_group(mesh_dim=None)返回由mesh_dim指定的单个ProcessGroup。如果未指定mesh_dim且DeviceMesh是一维的,则返回该mesh中唯一的ProcessGroup。
参数
- mesh_dim (str/python:int, 可选)- 可以是mesh维度的名称或索引
- None. (默认值为)-
返回
一个ProcessGroup对象。
返回类型:ProcessGroup
get_local_rank(mesh_dim=None)返回给定设备网格维度(mesh_dim)的本地秩。
参数
- mesh_dim (str/python:int, 可选)- 可以是网格维度的名称或索引
- None. (网格维度的默认值)-
返回值:表示本地秩的整数值。
返回类型:int
以下程序以SPMD方式在每个进程/秩上运行。本例中,我们使用2台主机,每台主机配备4个GPU。
在秩0、1、2、3上调用mesh_2d.get_local_rank(mesh_dim=0)将返回0;在秩4、5、6、7上调用mesh_2d.get_local_rank(mesh_dim=0)将返回1;在秩0、4上调用mesh_2d.get_local_rank(mesh_dim=1)将返回0;在秩1、5上调用mesh_2d.get_local_rank(mesh_dim=1)将返回1。
在秩2、6上调用mesh_2d.get_local_rank(mesh_dim=1)将返回2;在秩3、7上调用mesh_2d.get_local_rank(mesh_dim=1)将返回3。
示例:
>>
>
from torch.distributed.device_mesh import DeviceMesh
>>
>
>
>>
>
# Initialize device mesh as (2, 4) to represent the topology
>>
>
# of cross-host(dim 0), and within-host (dim 1).
>>
> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])get_rank()返回当前全局排名。
返回值类型:int
点对点通信
torch.distributed.send(tensor, dst=None, group=None, tag=0, group_dst=None)同步发送张量。
警告:NCCL后端不支持tag参数。
参数说明
- tensor ( Tensor )- 要发送的张量。
- dst ( int )- 全局进程组中的目标rank(不受- group参数影响)。目标rank不应与当前进程的rank相同。
- group (ProcessGroup, 可选)- 要操作的工作进程组。如果为None,将使用默认进程组。
- tag ( int , 可选)- 用于匹配远程接收操作的标记
- group_dst ( int , 可选)- 在- group中的目标rank。不能同时指定- dst和- group_dst参数。
torch.distributed.recv(tensor, src=None, group=None, tag=0, group_src=None)同步接收一个张量。
警告:NCCL后端不支持tag参数。
参数
- tensor ( Tensor )- 用于填充接收数据的张量。
- src ( int , 可选)- 全局进程组中的源rank(不受- group参数影响)。若未指定,将从任意进程接收数据。
- group (ProcessGroup, 可选)- 要操作的工作进程组。若为None,则使用默认进程组。
- tag ( int , 可选)- 用于匹配远程发送操作的标签
- group_src ( int , 可选)- 目标进程在- group中的rank。不可同时指定- src和- group_src。
返回值:发送方rank
- 若不属于该进程组,则返回-1
返回类型:int
在使用时会返回分布式请求对象。通常不建议手动创建这类对象,因此其具体类型不作规定,但保证支持以下两种方法:
- is_completed()- 若操作完成则返回True
- wait()- 阻塞进程直至操作完成
is_completed()方法一旦返回结果,其返回值必定为True。
torch.distributed.isend(tensor, dst=None, group=None, tag=0, group_dst=None)异步发送张量。
警告:在请求完成前修改 tensor 会导致未定义行为。
警告:NCCL 后端不支持 tag 参数。
与阻塞式的 send 不同,isend 允许 src == dst 排名,即支持向自身发送。
参数
- tensor (Tensor)– 待发送的张量。
- dst (int)– 全局进程组中的目标排名(不受- group参数影响)。
- group (ProcessGroup, 可选)– 操作的目标进程组。若为 None,则使用默认进程组。
- tag (int, 可选)– 用于匹配远程 recv 的标记。
- group_dst (int, 可选)–- group中的目标排名。不可同时指定- dst和- group_dst。
返回
一个分布式请求对象。若不属于该进程组则返回 None。
返回类型
Optional[Work]
torch.distributed.irecv(tensor, src=None, group=None, tag=0, group_src=None)异步接收一个张量。
警告:NCCL后端不支持tag参数。
与阻塞式的recv不同,irecv允许src等于dst的rank,即可以从自身接收数据。
参数
- tensor ( Tensor )– 用于填充接收数据的张量。
- src ( int , 可选)– 全局进程组中的源rank(不受- group参数影响)。如果未指定,将从任意进程接收数据。
- group (ProcessGroup, 可选)– 要操作的工作进程组。如果为None,则使用默认进程组。
- tag ( int , 可选)– 用于匹配远程发送的接收标记
- group_src ( int , 可选)– 在- group中的目标rank。不能同时指定- src和- group_src。
返回值:一个分布式请求对象。
如果不在该进程组中,则返回None
返回类型:Optional[Work]
torch.distributed.send_object_list(object_list, dst=None, group=None, device=None, group_dst=None)同步发送 object_list 中可序列化的对象。
与 send() 类似,但可以传递 Python 对象。
注意,object_list 中的所有对象必须可序列化才能发送。
参数
- object_list (List[Any])– 要发送的输入对象列表。每个对象必须可序列化。接收方必须提供大小相等的列表。
- dst (int)– 发送- object_list的目标 rank。目标 rank 基于全局进程组(与- group参数无关)。
- group (Optional[ProcessGroup])– (可选)要操作的进程组。如果为 None,则使用默认进程组。默认为- None。
- device (torch.device, optional)– 如果不为 None,对象会被序列化并转换为张量,发送前移动到- device。默认为- None。
- group_dst (int, optional)–- group上的目标 rank。必须指定- dst或- group_dst之一,但不能同时指定。
返回
None。
注意:对于基于 NCCL 的进程组,对象的内部张量表示必须在通信前移动到 GPU 设备。此时使用的设备由 torch.cuda.current_device() 给出,用户需确保通过 torch.cuda.set_device() 设置,使每个 rank 拥有独立的 GPU。
警告:send_object_list() 隐式使用 pickle 模块,已知其不安全。恶意构造的 pickle 数据可能在反序列化时执行任意代码。仅对可信数据调用此函数。
警告:使用 GPU 张量调用 send_object_list() 支持不佳且效率低下,因为张量会被序列化,导致 GPU-CPU 传输。建议改用 send()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
# Assumes backend is not NCCL
>>
> device = torch.device("cpu")
>>
>
if dist.get_rank() == 0:
>>
>
# Assumes world_size of 2、>> objects = ["foo", 12, {1: 2}] # any picklable object
>>
> dist.send_object_list(objects, dst=1, device=device)
>>
>
else:
>>
> objects = [None, None, None]
>>
> dist.recv_object_list(objects, src=0, device=device)
>>
> objects
['foo', 12, {
1: 2
}]torch.distributed.recv_object_list(object_list, src=None, group=None, device=None, group_src=None)同步接收object_list中的可序列化对象。
类似于recv(),但可以接收Python对象。
参数
- object_list (List[Any])- 用于接收对象的列表。必须提供一个与发送列表大小相等的尺寸列表。
- src (int, 可选)- 接收- object_list的源进程排名。源排名基于全局进程组(无论- group参数如何)。如果设置为None,将从任意排名接收。默认为- None。
- group (Optional[ProcessGroup])- (ProcessGroup, 可选): 要操作的进程组。如果为None,将使用默认进程组。默认为- None。
- device (torch.device- , 可选)- 如果不为None,则在此设备上接收。默认为- None。
- group_src (int, 可选)-- group上的目标排名。不能同时指定- src和- group_src。
返回
发送方排名。如果排名不属于该组,则为-1。如果排名属于该组,object_list将包含来自src排名的发送对象。
注意:对于基于NCCL的进程组,对象的内部张量表示必须在通信之前移动到GPU设备。在这种情况下,使用的设备由torch.cuda.current_device()给出,用户有责任通过torch.cuda.set_device()确保每个排名都有一个单独的GPU。
警告:recv_object_list()隐式使用pickle模块,已知其不安全。可能构造恶意的pickle数据,在反序列化期间执行任意代码。仅对可信数据调用此函数。
警告:使用GPU张量调用recv_object_list()不受良好支持且效率低下,因为张量会被pickle,导致GPU-CPU传输。请考虑改用recv()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
# Assumes backend is not NCCL
>>
> device = torch.device("cpu")
>>
>
if dist.get_rank() == 0:
>>
>
# Assumes world_size of 2、>> objects = ["foo", 12, {1: 2}] # any picklable object
>>
> dist.send_object_list(objects, dst=1, device=device)
>>
>
else:
>>
> objects = [None, None, None]
>>
> dist.recv_object_list(objects, src=0, device=device)
>>
> objects
['foo', 12, {
1: 2
}]torch.distributed.batch_isend_irecv(p2p_op_list)异步发送或接收一批张量并返回请求列表。
处理 p2p_op_list 中的每个操作,并返回对应的请求。当前支持 NCCL、Gloo 和 UCC 后端。
参数
- p2p_op_list (list[torch.distributed.distributed_c10d.P2POp])– 点对点操作列表(每个操作的类型为- torch.distributed.P2POp)。列表中的 isend/irecv 顺序很重要,需要与远程端的对应 isend/irecv 匹配。
返回
通过调用 op_list 中的对应操作返回的分布式请求对象列表。
返回类型
list [torch.distributed.distributed_c10d.Work]
示例:
>>
> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
>>
> recv_tensor = torch.randn(2, dtype=torch.float32)
>>
> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
>>
> recv_op = dist.P2POp(
... dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
... )
>>
> reqs = batch_isend_irecv([send_op, recv_op])
>>
>
for req in reqs:
>>
> req.wait()
>>
> recv_tensor
tensor([2, 3]) # Rank 0
tensor([0, 1]) # Rank 1注意:当此API与NCCL PG后端一起使用时,用户必须通过torch.cuda.set_device设置当前GPU设备,否则会导致意外的挂起问题。
此外,如果此API是传入dist.P2POp的group中的第一个集合通信调用,则该group的所有进程都必须参与此次API调用;否则行为将是未定义的。如果此API调用不是group中的第一个集合通信操作,则允许仅涉及group中部分进程的批量P2P操作。
class torch
.distributed.P2POp(op, tensor, peer=None, group=None, tag=0, group_peer=None)一个用于为batch_isend_irecv构建点对点操作的类。
该类构建P2P操作类型、通信缓冲区、对等节点秩、进程组和标签。此类的实例将被传递给batch_isend_irecv以进行点对点通信。
参数
- op (Callable)– 用于向对等进程发送或接收数据的函数。
op的类型为torch.distributed.isend或torch.distributed.irecv。
- tensor ( Tensor )– 要发送或接收的张量。
- peer ( int , optional)– 目标或源秩。
- group (ProcessGroup, optional)– 要操作的进程组。如果为None,将使用默认进程组。
- tag ( int , optional)– 用于匹配发送与接收的标签。
- group_peer ( int , optional)– 目标或源秩。
同步与异步集合操作
每个集合操作函数都支持以下两种操作模式,具体取决于传入的async_op标志设置:
同步操作 - 默认模式,当async_op设为False时生效。函数返回时,可以确保集合操作已执行完成。对于CUDA操作而言,由于CUDA操作本身是异步的,此时不能保证CUDA操作已完成。对于CPU集合操作,后续使用该操作输出的函数调用将按预期工作。对于CUDA集合操作,在同一个CUDA流中使用输出的函数调用将按预期工作。若在不同流中运行,用户需自行处理同步问题。有关CUDA语义(如流同步)的详细信息,请参阅CUDA语义。下方脚本展示了CPU与CUDA操作在这些语义上的差异示例。
异步操作 - 当async_op设为True时生效。集合操作函数会返回一个分布式请求对象。通常无需手动创建该对象,它保证支持以下方法:
- is_completed()- 对于CPU集合操作,完成时返回- True。对于CUDA操作,当操作成功加入CUDA流且输出可在默认流中使用而无需额外同步时返回- True
- wait()- 对于CPU集合操作,将阻塞进程直至操作完成。对于CUDA集合操作,将阻塞当前活跃的CUDA流直至操作完成(但不会阻塞CPU)
- get_future()- 返回- torch._C.Future对象。支持NCCL后端,也支持GLOO和MPI后端的大多数操作(点对点操作除外)
 注意:随着我们持续采用Future并合并API,- get_future()调用可能会变得冗余
示例
以下代码可作为使用分布式集合操作时CUDA操作语义的参考,展示了在不同CUDA流中使用集合操作输出时需要显式同步的情况:
# Code runs on each rank.
dist.init_process_group("nccl", rank=rank, world_size=2)
output = torch.tensor([rank]).cuda(rank)
s = torch.cuda.Stream()
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream. with torch.cuda.stream(s):
s.wait_stream(torch.cuda.default_stream())
output.add_(100) if rank == 0:
# if the explicit call to wait_stream was omitted, the output below will be # non-deterministically 1 or 101, depending on whether the allreduce overwrote
# the value after the add completed.
print(output)集合函数
torch.distributed.broadcast(tensor, src=None, group=None, async_op=False, group_src=None)将张量广播到整个进程组。
所有参与集体通信的进程中,tensor 必须具有相同的元素数量。
参数说明
- tensor ( Tensor )- 如果当前进程是源进程(- src),则作为待发送数据;否则作为接收数据的存储张量。
- src ( int )- 全局进程组中的源进程排名(不受- group参数影响)。
- group (ProcessGroup, 可选)- 操作的进程组。若为None,则使用默认进程组。
- async_op ([bool], 可选)- 是否作为异步操作执行。
- group_src ( int )- 指定- group内的源进程排名。必须且只能指定- group_src或- src中的一个。
返回值
- 若async_op设为True,返回异步操作句柄。
- 若非异步操作或不属于该进程组,返回None。
torch.distributed.broadcast_object_list(object_list, src=None, group=None, device=None, group_src=None)将 object_list 中的可序列化对象广播到整个组。
类似于 broadcast(),但可以传入 Python 对象。
注意,object_list 中的所有对象必须可序列化才能被广播。
参数
- object_list (List[Any])– 要广播的输入对象列表。每个对象必须可序列化。只有- src进程上的对象会被广播,但每个进程必须提供大小相同的列表。
- src ( int )– 广播- object_list的源进程号。源进程号基于全局进程组(与- group参数无关)。
- group (Optional[ProcessGroup])– (可选)要操作的进程组。如果为 None,则使用默认进程组。默认为- None。
- device (torch.device- , optional)– 如果非 None,对象会被序列化并转换为张量,广播前移动到- device。默认为- None。
- group_src ( int )–- group上的源进程号。不能同时指定- group_src和- src。
返回
None。如果当前进程属于该组,object_list 将包含从 src 进程广播的对象。
注意:对于基于 NCCL 的进程组,对象的内部张量表示必须在通信前移动到 GPU 设备。此时使用的设备由 torch.cuda.current_device() 给出,用户需确保通过 torch.cuda.set_device() 设置每个进程有独立的 GPU。
注意:此 API 与 broadcast() 略有不同,因为它不提供 async_op 句柄,因此是阻塞调用。
警告:broadcast_object_list() 隐式使用 pickle 模块,已知其不安全。恶意构造的 pickle 数据可能在反序列化时执行任意代码。仅对可信数据调用此函数。
警告:使用 GPU 张量调用 broadcast_object_list() 支持不佳且效率低下,因为张量会被序列化导致 GPU-CPU 传输。建议改用 broadcast()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
if dist.get_rank() == 0:
>>
>
# Assumes world_size of 3、>> objects = ["foo", 12, {1: 2}] # any picklable object
>>
>
else:
>>
> objects = [None, None, None]
>>
>
# Assumes backend is not NCCL
>>
> device = torch.device("cpu")
>>
> dist.broadcast_object_list(objects, src=0, device=device)
>>
> objects
['foo', 12, {
1: 2
}]torch.distributed.all_reduce(tensor, op=<RedOpType.SUM: 0>
  , group=None, async_op=False)以所有机器都能获取最终结果的方式对张量数据进行归约操作。
调用后,所有进程中的 tensor 将保持二进制级别的一致性。
支持复数张量。
参数
- tensor (Tensor)- 集合操作的输入和输出张量。该函数会就地修改张量。
- op (可选)- 从- torch.distributed.ReduceOp枚举中选择的操作类型。指定用于逐元素归约的运算方式。
- group (ProcessGroup, 可选)- 要操作的工作进程组。若为 None,则使用默认进程组。
- async_op (bool, 可选)- 是否将此操作设为异步操作。
返回
- 若 async_op 设为 True,返回异步操作句柄。
- 若非异步操作或不属于该进程组,则返回 None。
示例:
>>
>
# All tensors below are of torch.int64 type.
>>
>
# We have 2 process groups, 2 ranks.
>>
> device = torch.device(f"cuda:{rank
}")
>>
> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>
> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>
> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>
> tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1>>
>
# All tensors below are of torch.cfloat type.
>>
>
# We have 2 process groups, 2 ranks.
>>
> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>
> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>
> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>
> tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1torch.distributed.reduce(tensor, dst=None, op=<RedOpType.SUM: 0>
  , group=None, async_op=False, group_dst=None)在所有机器间对张量数据进行归约操作。
只有排名为 dst 的进程会接收到最终结果。
参数
- tensor ( Tensor )– 集合操作的输入和输出张量。该函数会就地修改数据。
- dst ( int )– 全局进程组中的目标排名(不受- group参数影响)
- op (可选)– 从- torch.distributed.ReduceOp枚举中选择的值。指定用于逐元素归约的操作类型。
- group (ProcessGroup, 可选)– 要操作的目标进程组。若为 None,则使用默认进程组。
- async_op ([bool], 可选)– 是否将此操作设为异步操作
- group_dst ( int )– 在- group上的目标排名。必须指定- group_dst和- dst中的一个,但不能同时指定两者。
返回值:若 async_op 设为 True,则返回异步操作句柄。
若未设置 async_op 或不属于该进程组,则返回 None
torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)从整个进程组中收集张量到列表中。
支持复杂且大小不一的张量。
参数
- tensor_list (list[Tensor])- 输出列表。该列表应包含正确尺寸的张量,用于集合通信的输出。支持大小不一的张量。
- tensor (Tensor)- 从当前进程广播的张量。
- group (ProcessGroup, 可选)- 要操作的进程组。如果为None,则使用默认进程组。
- async_op ([bool], 可选)- 该操作是否应为异步操作
返回
如果async_op设置为True,则返回异步工作句柄。
如果不设置async_op或不属于该进程组,则返回None
示例:
>>
>
# All tensors below are of torch.int64 dtype.
>>
>
# We have 2 process groups, 2 ranks.
>>
> device = torch.device(f"cuda:{rank
}")
>>
> tensor_list = [
... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
... ]
>>
> tensor_list
[tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
[tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
>>
> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>
> tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>
> dist.all_gather(tensor_list, tensor)
>>
> tensor_list
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
[tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1>>
>
# All tensors below are of torch.cfloat dtype.
>>
>
# We have 2 process groups, 2 ranks.
>>
> tensor_list = [
... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
... ]
>>
> tensor_list
[tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
[tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
>>
> tensor = torch.tensor(
... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
... ) + 2 * rank * (1 + 1j)
>>
> tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>
> dist.all_gather(tensor_list, tensor)
>>
> tensor_list
[tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
[tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False)从所有进程收集张量并合并为一个输出张量。
此函数要求每个进程上的所有张量大小相同。
参数
- output_tensor (Tensor)- 用于容纳来自所有进程张量元素的输出张量。其尺寸必须正确设置为以下形式之一:
(i) 沿主维度拼接所有输入张量;关于"拼接"的定义,请参阅 torch.cat();
(ii) 沿主维度堆叠所有输入张量;关于"堆叠"的定义,请参阅 torch.stack()。
下方示例可以更清楚地说明支持的输出形式。
- input_tensor (Tensor)- 从当前进程收集的输入张量。
与 all_gather API 不同,本 API 要求所有进程的输入张量必须具有相同大小。
- group (ProcessGroup, 可选)- 要操作的工作进程组。如果为 None,则使用默认进程组。
- async_op ([bool], 可选)- 是否将此操作设为异步操作
返回
如果 async_op 设为 True,则返回异步操作句柄。
如果不设 async_op 或不属于该进程组,则返回 None
示例:
>>
>
# All tensors below are of torch.int64 dtype and on CUDA devices.
>>
>
# We have two ranks.
>>
> device = torch.device(f"cuda:{rank
}")
>>
> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>
> tensor_in
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>
>
# Output in concatenation form
>>
> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
>>
> dist.all_gather_into_tensor(tensor_out, tensor_in)
>>
> tensor_out
tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
>>
>
# Output in stack form
>>
> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
>>
> dist.all_gather_into_tensor(tensor_out2, tensor_in)
>>
> tensor_out2
tensor([[1, 2], [3, 4]], device='cuda:0') # Rank 0
tensor([[1, 2], [3, 4]], device='cuda:1') # Rank 1警告:Gloo 后端不支持此 API。
torch.distributed.all_gather_object(object_list, obj, group=None)将整个组中的可pickle对象收集到一个列表中。
类似于 all_gather(),但可以传递Python对象。
注意:对象必须是可pickle的才能被收集。
参数
- object_list (list[Any])– 输出列表。其大小应正确设置为该集合操作的组大小,并将包含输出结果。
- obj (Any)– 从当前进程广播的可pickle的Python对象。
- group (ProcessGroup, 可选)– 要操作的工作进程组。如果为None,则使用默认进程组。默认为- None。
返回
无。如果调用rank属于该组,集合操作的输出将填充到输入的object_list中。如果调用rank不属于该组,传入的object_list将保持不变。
注意:请注意此API与 all_gather() 集合操作略有不同,因为它不提供async_op句柄,因此将是一个阻塞调用。
注意:对于基于NCCL的进程组,对象的内部张量表示必须在通信发生前移动到GPU设备。这种情况下,使用的设备由torch.cuda.current_device()给出,用户有责任通过torch.cuda.set_device()确保每个rank都有独立的GPU。
警告:all_gather_object() 隐式使用pickle模块,已知该模块不安全。可能构造恶意的pickle数据,在反序列化时执行任意代码。仅对可信数据调用此函数。
警告:使用GPU张量调用 all_gather_object() 支持不佳且效率低下,因为张量需要被pickle会导致GPU-CPU传输。请考虑改用 all_gather()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
# Assumes world_size of 3、>>gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>
> output = [None for _ in gather_objects]
>>
> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>
> output
['foo', 12, {
1: 2
}]torch.distributed.gather(tensor, gather_list=None, dst=None, group=None, async_op=False, group_dst=None)将多个进程中的张量列表收集到单个进程中。
此函数要求每个进程中的所有张量大小必须相同。
参数
- tensor ( Tensor )– 输入张量。
- gather_list (list[Tensor ], 可选)– 用于收集数据的适当大小且尺寸相同的张量列表(默认为None,必须在目标rank上指定)
- dst ( int , 可选)– 全局进程组中的目标rank(不受- group参数影响)。(如果- dst和- group_dst均为None,则默认为全局rank 0)
- group (ProcessGroup, 可选)– 要操作的工作进程组。如果为None,则使用默认进程组。
- async_op ([bool], 可选)– 此操作是否应为异步操作
- group_dst ( int , 可选)–- group中的目标rank。不允许同时指定- dst和- group_dst
返回值:如果async_op设置为True,则返回异步工作句柄。
如果未设置async_op或不属于该进程组,则返回None
注意:gather_list中的所有张量必须具有相同的大小。
示例:
>>
>
# We have 2 process groups, 2 ranks.
>>
> tensor_size = 2
>>
> device = torch.device(f'cuda:{rank
}')
>>
> tensor = torch.ones(tensor_size, device=device) + rank
>>
>
if dist.get_rank() == 0:
>>
> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
>>
>
else:
>>
> gather_list = None
>>
> dist.gather(tensor, gather_list, dst=0)
>>
>
# Rank 0 gets gathered data.
>>
> gather_list
[tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
None # Rank 1torch.distributed.gather_object(obj, object_gather_list=None, dst=None, group=None, group_dst=None)从整个进程组中收集可序列化对象到单个进程。
功能类似于 gather(),但支持传递Python对象。注意:待收集的对象必须可序列化。
参数
- obj (Any)– 输入对象,必须可序列化。
- object_gather_list (list[Any])– 输出列表。在目标- dst进程上,该列表需预先分配为进程组大小的空间以存储结果。非目标进程上必须设为- None(默认值为- None)。
- dst ( int , optional)– 全局进程组中的目标进程编号(不受- group参数影响)。若- dst和- group_dst均为None,则默认为全局0号进程。
- group (Optional[ProcessGroup])– 操作的目标进程组。若为None则使用默认进程组(默认值为- None)。
- group_dst ( int , optional)– 指定- group参数对应进程组中的目标进程编号。不可同时指定- dst和- group_dst。
返回值
无。在目标dst进程上,object_gather_list将包含集合操作的结果。
注意
本API与常规gather操作略有不同:不提供async_op异步句柄,因此是阻塞调用。
注意
对于基于NCCL的进程组,对象内部的张量表示必须在通信前移至GPU设备。此时设备由torch.cuda.current_device()决定,用户需通过torch.cuda.set_device()确保每个进程独占GPU。
警告
gather_object()隐式使用pickle模块,该模块存在安全隐患。恶意构造的pickle数据可能在反序列化时执行任意代码。请仅对可信数据调用此函数。
警告
对GPU张量调用gather_object()支持不佳且效率低下,因为序列化会引发GPU-CPU传输。建议改用gather()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
# Assumes world_size of 3、>>gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>
> output = [None for _ in gather_objects]
>>
> dist.gather_object(
... gather_objects[dist.get_rank()],
... output if dist.get_rank() == 0 else None,
... dst=0
... )
>>
>
# On rank 0
>>
> output
['foo', 12, {
1: 2
}]torch.distributed.scatter(tensor, scatter_list=None, src=None, group=None, async_op=False, group_src=None)将一组张量分散到进程组中的所有进程。
每个进程将准确接收一个张量,并将其数据存储在 tensor 参数中。
支持复数张量。
参数
- tensor ( Tensor )– 输出张量。
- scatter_list (list[Tensor ])– 要分散的张量列表(默认为 None,必须在源 rank 上指定)
- src ( int )– 全局进程组中的源 rank(不受- group参数影响)。
(如果 src 和 group_src 均为 None,则默认为全局 rank 0)
- group (ProcessGroup, 可选)– 要操作的进程组。如果为 None,则使用默认进程组。
- async_op ([bool], 可选)– 此操作是否应为异步操作
- group_src ( int , 可选)–- group中的源 rank。不能同时指定- src和- group_src
返回
如果 async_op 设置为 True,则返回异步工作句柄。
如果不为 async_op 或不属于该组,则返回 None
注意:请注意,scatter_list 中的所有张量必须具有相同的大小。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
> tensor_size = 2
>>
> device = torch.device(f'cuda:{rank
}')
>>
> output_tensor = torch.zeros(tensor_size, device=device)
>>
>
if dist.get_rank() == 0:
>>
>
# Assumes world_size of 2、>> # Only tensors, all of which must be the same size.
>>
> t_ones = torch.ones(tensor_size, device=device)
>>
> t_fives = torch.ones(tensor_size, device=device) * 5
>>
> scatter_list = [t_ones, t_fives]
>>
>
else:
>>
> scatter_list = None
>>
> dist.scatter(output_tensor, scatter_list, src=0)
>>
>
# Rank i gets scatter_list[i].
>>
> output_tensor
tensor([1., 1.], device='cuda:0') # Rank 0
tensor([5., 5.], device='cuda:1') # Rank 1torch.distributed.scatter_object_list(scatter_object_output_list, scatter_object_input_list=None, src=None, group=None, group_src=None)将 scatter_object_input_list 中的可序列化对象分发到整个组中。
类似于 scatter(),但可以传递 Python 对象。在每个 rank 上,分发的对象将作为 scatter_object_output_list 的第一个元素存储。注意,scatter_object_input_list 中的所有对象必须可序列化才能被分发。
参数
- scatter_object_output_list (List[Any])– 非空列表,其第一个元素将存储分发到当前 rank 的对象。
- scatter_object_input_list (List[Any], optional)– 要分发的输入对象列表。每个对象必须可序列化。只有- srcrank 上的对象会被分发,非 src rank 可以传入- None。
- src ( int )– 分发- scatter_object_input_list的源 rank。源 rank 基于全局进程组(与- group参数无关)。(如果- src和- group_src均为 None,则默认为全局 rank 0)
- group (Optional[ProcessGroup])– (ProcessGroup,可选):要操作的进程组。如果为 None,则使用默认进程组。默认为- None。
- group_src ( int , optional)–- group上的源 rank。不能同时指定- src和- group_src。
返回值
None。如果当前 rank 属于该组,scatter_object_output_list 的第一个元素将被设置为分发到该 rank 的对象。
注意:请注意此 API 与 scatter 集合操作略有不同,因为它不提供 async_op 句柄,因此是一个阻塞调用。
警告:scatter_object_list() 隐式使用了 pickle 模块,已知该模块不安全。可能构造恶意的 pickle 数据,在反序列化时执行任意代码。请仅对可信数据调用此函数。
警告:使用 GPU 张量调用 scatter_object_list() 支持不佳且效率低下,因为张量需要序列化会导致 GPU-CPU 传输。请考虑改用 scatter()。
示例:
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
if dist.get_rank() == 0:
>>
>
# Assumes world_size of 3、>> objects = ["foo", 12, {1: 2}] # any picklable object
>>
>
else:
>>
>
# Can be any list on non-src ranks, elements are not used.
>>
> objects = [None, None, None]
>>
> output_list = [None]
>>
> dist.scatter_object_list(output_list, objects, src=0)
>>
>
# Rank i gets objects[i]. For example, on rank 2:
>>
> output_list
[{
1: 2
}]torch.distributed.reduce_scatter(output, input_list, op=<RedOpType.SUM: 0>
  , group=None, async_op=False)将一组张量进行归约后分散到进程组中的所有进程。
参数
- output ( Tensor )– 输出张量。
- input_list (list[Tensor ])– 待归约和分散的张量列表。
- op (可选)– 从- torch.distributed.ReduceOp枚举中选择的值。指定用于逐元素归约的操作。
- group (ProcessGroup, 可选)– 要操作的进程组。如果为 None,则使用默认进程组。
- async_op ([bool], 可选)– 此操作是否应为异步操作。
返回值:如果 async_op 设为 True,则返回异步工作句柄。
如果不为异步操作或不属于该进程组,则返回 None。
torch.distributed.reduce_scatter_tensor(output, input, op=<RedOpType.SUM: 0>
  , group=None, async_op=False)对张量进行归约操作后,将其分散到组内所有进程中。
参数
- output (Tensor)- 输出张量。所有进程中的该张量应保持相同大小。
- input (Tensor)- 待归约和分散的输入张量。其大小应为输出张量大小乘以进程组规模。输入张量可具有以下两种形状之一:
(i) 沿主维度拼接的输出张量序列,或
(ii) 沿主维度堆叠的输出张量序列。
关于"拼接"的定义,请参阅 torch.cat()。
关于"堆叠"的定义,请参阅 torch.stack()。
- group (ProcessGroup, 可选)- 要操作的进程组。若为None,则使用默认进程组。
- async_op (bool, 可选)- 是否将此操作设为异步操作。
返回
若 async_op 设为 True,返回异步工作句柄。
若未设置 async_op 或不属于该进程组,返回 None。
示例:
>>
>
# All tensors below are of torch.int64 dtype and on CUDA devices.
>>
>
# We have two ranks.
>>
> device = torch.device(f"cuda:{rank
}")
>>
> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
>>
>
# Input in concatenation form
>>
> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
>>
> tensor_in
tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
>>
> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>
> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1
>>
>
# Input in stack form
>>
> tensor_in = torch.reshape(tensor_in, (world_size, 2))
>>
> tensor_in
tensor([[0, 1], [2, 3]], device='cuda:0') # Rank 0
tensor([[0, 1], [2, 3]], device='cuda:1') # Rank 1
>>
> dist.reduce_scatter_tensor(tensor_out, tensor_in)
>>
> tensor_out
tensor([0, 2], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1警告:Gloo 后端不支持此 API。
torch.distributed.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False)将输入张量分割后分散到组内所有进程中。
随后从组内所有进程接收到的张量会被拼接起来,作为单个输出张量返回。
支持复数张量。
参数
- output ( Tensor )– 收集拼接后的输出张量。
- input ( Tensor )– 待分散的输入张量。
- output_split_sizes– (list[Int], 可选): 如果指定为None或空列表,则要求- output张量的第0维必须能被- world_size整除;否则指定第0维的输出分割尺寸。
- input_split_sizes– (list[Int], 可选): 如果指定为None或空列表,则要求- input张量的第0维必须能被- world_size整除;否则指定第0维的输入分割尺寸。
- group (ProcessGroup, 可选)– 要操作的工作进程组。如果为None,则使用默认进程组。
- async_op ([bool], 可选)– 是否将此操作设为异步操作。
返回值:如果async_op设为True,则返回异步操作句柄。
如果不设async_op或不属于该进程组,则返回None。
警告:all_to_all_single是实验性功能,后续可能变更。
示例
>>
>
input = torch.arange(4) + rank * 4
>>
>
input
tensor([0, 1, 2, 3]) # Rank 0
tensor([4, 5, 6, 7]) # Rank 1
tensor([8, 9, 10, 11]) # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>
> output = torch.empty([4], dtype=torch.int64)
>>
> dist.all_to_all_single(output, input)
>>
> output
tensor([0, 4, 8, 12]) # Rank 0
tensor([1, 5, 9, 13]) # Rank 1
tensor([2, 6, 10, 14]) # Rank 2
tensor([3, 7, 11, 15]) # Rank 3>>
>
# Essentially, it is similar to following operation:
>>
> scatter_list = list(input.chunk(world_size))
>>
> gather_list = list(output.chunk(world_size))
>>
>
for i in range(world_size):
>>
> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i)>>
>
# Another example with uneven split
>>
>
input
tensor([0, 1, 2, 3, 4, 5]) # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
tensor([20, 21, 22, 23, 24]) # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3
>>
> input_splits
[2, 2, 1, 1] # Rank 0
[3, 2, 2, 2] # Rank 1
[2, 1, 1, 1] # Rank 2
[2, 2, 2, 1] # Rank 3
>>
> output_splits
[2, 3, 2, 2] # Rank 0
[2, 2, 1, 2] # Rank 1
[1, 2, 1, 2] # Rank 2
[1, 2, 1, 1] # Rank 3
>>
> output = ...
>>
> dist.all_to_all_single(output, input, output_splits, input_splits)
>>
> output
tensor([0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0
tensor([2, 3, 13, 14, 22, 32, 33]) # Rank 1
tensor([4, 15, 16, 23, 34, 35]) # Rank 2
tensor([5, 17, 18, 24, 36]) # Rank 3>>
>
# Another example with tensors of torch.cfloat type.
>>
>
input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>
>
input
tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0
tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1
tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2
tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3
>>
> output = torch.empty([4], dtype=torch.int64)
>>
> dist.all_to_all_single(output, input)
>>
> output
tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0
tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1
tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2
tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False)将输入张量列表分散到组内所有进程,并返回聚合后的输出张量列表。
支持复数张量。
参数
- output_tensor_list (list[Tensor])- 每个rank待聚合的张量列表。
- input_tensor_list (list[Tensor])- 每个rank待分散的张量列表。
- group (ProcessGroup, 可选)- 操作的工作进程组。若为None,则使用默认进程组。
- async_op (bool, 可选)- 是否将操作设为异步模式。
返回值
- 若async_op设为True,返回异步操作句柄。
- 若非异步模式或不属于该进程组,则返回None。
警告:all_to_all接口处于实验阶段,后续可能变更。
示例:
>>
>
input = torch.arange(4) + rank * 4
>>
>
input = list(input.chunk(4))
>>
>
input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>
> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>
> dist.all_to_all(output, input)
>>
> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3>>
>
# Essentially, it is similar to following operation:
>>
> scatter_list = input
>>
> gather_list = output
>>
>
for i in range(world_size):
>>
> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)>>
>
input
tensor([0, 1, 2, 3, 4, 5]) # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1
tensor([20, 21, 22, 23, 24]) # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3
>>
> input_splits
[2, 2, 1, 1] # Rank 0
[3, 2, 2, 2] # Rank 1
[2, 1, 1, 1] # Rank 2
[2, 2, 2, 1] # Rank 3
>>
> output_splits
[2, 3, 2, 2] # Rank 0
[2, 2, 1, 2] # Rank 1
[1, 2, 1, 2] # Rank 2
[1, 2, 1, 1] # Rank 3
>>
>
input = list(input.split(input_splits))
>>
>
input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3
>>
> output = ...
>>
> dist.all_to_all(output, input)
>>
> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3>>
>
# Another example with tensors of torch.cfloat type.
>>
>
input = torch.tensor(
... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
... ) + 4 * rank * (1 + 1j)
>>
>
input = list(input.chunk(4))
>>
>
input
[tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0
[tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1
[tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2
[tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3
>>
> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>
> dist.all_to_all(output, input)
>>
> output
[tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0
[tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1
[tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2
[tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3torch.distributed.barrier(group=None, async_op=False, device_ids=None)同步所有进程。
如果 async_op 为 False,或者对 wait() 调用了异步工作句柄,该集合操作会阻塞进程,直到整个组进入此函数。
参数
- group (ProcessGroup, 可选)– 要操作的进程组。如果为 None,则使用默认进程组。
- async_op ([bool], 可选)– 该操作是否为异步操作
- device_ids ([int], 可选)– 设备/GPU ID 列表。
返回值:如果 async_op 设为 True,返回异步工作句柄。
如果不为 async_op 或不属于该组,返回 None。
注意:ProcessGroupNCCL 现在会阻塞 CPU 线程,直到屏障集合操作完成。
torch.distributed.monitored_barrier(group=None, timeout=None, wait_all_ranks=False)实现类似torch.distributed.barrier的进程同步功能,但支持可配置的超时机制。
该机制能够报告在指定超时时间内未能通过屏障的进程排名(ranks)。
具体而言:
- 对于非0排名进程,会阻塞直至完成与rank 0的发送/接收操作
- Rank 0进程会阻塞直至处理完所有其他进程的发送/接收操作,并上报超时未响应的进程排名
- 注意:若任一进程未到达monitored_barrier(例如因挂起),所有其他进程都会在monitored_barrier处失败
这个集合操作会阻塞组内所有进程/排名,直到整个组成功退出该函数,因此非常适用于调试和同步场景。但需注意其性能开销,建议仅用于调试或需要主机端完全同步点的场景。调试时可在应用程序的集合调用前插入此屏障,用于检查是否存在进程不同步的情况。
注意:该集合操作仅支持GLOO后端。
参数说明
- group (ProcessGroup, 可选)- 要操作的工作进程组。若为- None则使用默认进程组
- timeout ([datetime.timedelta, 可选)- monitored_barrier的超时时间。若为- None则使用默认进程组超时设置
- wait_all_ranks ([bool], 可选)- 是否收集所有失败进程排名。默认为- False,此时rank 0上的monitored_barrier会在遇到第一个失败排名时立即抛出异常以实现快速失败。若设为- True则会收集所有失败排名并抛出包含全部失败信息的错误
返回值
None
使用示例
>>
>
# Note: Process group initialization omitted on each rank.
>>
>
import torch.distributed as dist
>>
>
if dist.get_rank() != 1:
>>
> dist.monitored_barrier() # Raises exception indicating that >># rank 1 did not call into monitored_barrier.
>>
>
# Example with wait_all_ranks=True
>>
>
if dist.get_rank() == 0:
>>
> dist.monitored_barrier(wait_all_ranks=True) # Raises exception
>>
>
# indicating that ranks 1, 2, 
... world_size - 1 did not call into
>>
>
# monitored_barrier.class torch
.distributed.WorkWork对象代表PyTorch分布式包中一个待处理的异步操作句柄。它由非阻塞的集合操作返回,例如dist.all_reduce(tensor, async_op=True)。
boxed(self: torch._C._distributed_c10d.Work) → object
exception(self: torch._C._distributed_c10d.Work) → std::__exception_ptr::exception_ptr
get_future(self: torch._C._distributed_c10d.Work) → torch.Future
返回值:一个与Work完成相关联的torch.futures.Future对象。例如,可以通过fut = process_group.allreduce(tensors).get_future()获取future对象。
示例:下面是一个简单的allreduce DDP通信钩子示例,它使用get_future API来检索与allreduce完成相关联的Future。
>>
>
def allreduce(process_group: dist.ProcessGroup, bucket: dist.GradBucket): -torch.futures.Future
>>
> group_to_use = process_group if process_group is not None else torch.distributed.group.WORLD
>>
> tensor = bucket.buffer().div_(group_to_use.size())
>>
>
return torch.distributed.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
>>
> ddp_model.register_comm_hook(state=None, hook=allreduce)警告:get_future API 支持 NCCL 后端,部分支持 GLOO 和 MPI 后端(不支持点对点操作如 send/recv),并将返回一个 torch.futures.Future。
在上述示例中,allreduce 操作将通过 NCCL 后端在 GPU 上执行。fut.wait() 会在 NCCL 流与 PyTorch 当前设备流同步后返回,以确保支持异步 CUDA 执行,而无需等待整个 GPU 操作完成。请注意,CUDAFuture 不支持 TORCH_NCCL_BLOCKING_WAIT 标志或 NCCL 的 barrier() 功能。
此外,若通过 fut.then() 添加了回调函数,该回调将等待 WorkNCCL 的 NCCL 流与 ProcessGroupNCCL 的专用回调流同步,并在回调流上执行后立即触发回调。fut.then() 会返回另一个 CUDAFuture,其中包含回调函数的返回值以及记录回调流的 CUDAEvent。
1、对于 CPU 任务,fut.done() 在任务完成且 value() 张量就绪时返回 true。
 2、对于 GPU 任务,fut.done() 仅在操作已加入队列时返回 true。
 3、对于 CPU-GPU 混合任务(例如通过 GLOO 发送 GPU 张量),fut.done() 在张量到达目标节点时返回 true,但 GPU 上的同步可能尚未完成(与纯 GPU 任务类似)。
get_future_result(self: torch._C._distributed_c10d.Work) → torch.Future返回
一个torch.futures.Future类型的对象,其整数值对应WorkResult枚举类型
例如,可以通过fut = process_group.allreduce(tensor).get_future_result()获取future对象。
示例:用户可以使用fut.wait()阻塞等待工作完成,并通过fut.value()获取WorkResult。
此外,用户还可以使用fut.then(call_back_func)注册回调函数,
该函数会在工作完成时被调用,且不会阻塞当前线程。
警告:get_future_result API仅支持NCCL
is_completed(self: torch._C._distributed_c10d.Work) → boolis_success(self: torch._C._distributed_c10d.Work) → boolresult(self: torch._C._distributed_c10d.Work) → list [torch.Tensor]获取工作对象的结果,返回一个包含torch.Tensor的列表
source_rank(self: torch._C._distributed_c10d.Work) → int获取发送该工作对象的源进程排名,返回一个整数值
synchronize(self: torch._C._distributed_c10d.Work) → Nonestatic unbox(arg0: object ) → torch._C._distributed_c10d.Workwait(self: torch._C._distributed_c10d.Work, timeout: [datetime.timedelta = datetime.timedelta(0)) → bool返回值 : true/false。
示例::
try:
work.wait(timeout)
except:
# some handling警告:通常情况下,用户无需设置超时参数。
调用 wait() 等同于调用 synchronize():
会使当前流阻塞直至 NCCL 工作完成。
但如果设置了超时参数,则会阻塞 CPU 线程直至 NCCL 工作完成或超时。若发生超时,将抛出异常。
class torch
.distributed.ReduceOp一个枚举类,用于表示可用的归约操作:SUM(求和)、PRODUCT(乘积)、MIN(最小值)、MAX(最大值)、BAND(按位与)、BOR(按位或)、BXOR(按位异或)以及PREMUL_SUM(预乘求和)。
注意事项:
- 当使用NCCL后端时,BAND、BOR和BXOR归约操作不可用。
- AVG(平均值)会在跨节点求和前将数值除以全局进程数。该操作仅支持- NCCL后端,且要求NCCL版本为2.10及以上。
- PREMUL_SUM会在归约前将输入张量乘以指定的标量。该操作仅支持- NCCL后端,且要求NCCL版本为2.11及以上。用户应使用- torch.distributed._make_nccl_premul_sum来调用。
- 复数张量不支持MAX、MIN和PRODUCT操作。
使用方式:
- 可通过属性访问枚举值,例如ReduceOp.SUM
- 用于指定集合通信的归约策略,例如reduce()
限制说明:
- 本类不支持__members__属性
class torch
.distributed.reduce_op已弃用的枚举式类,用于定义归约操作:SUM(求和)、PRODUCT(乘积)、MIN(最小值)和MAX(最大值)。
建议改用 ReduceOp 类。
分布式键值存储
分布式包内置了一个分布式键值存储,可用于在进程组之间共享信息,也可用于初始化分布式包(通过显式创建存储作为指定 init_method 的替代方案)。键值存储有三种选择:TCPStore、FileStore 和 HashStore。
class torch
.distributed.Store
Base class for
all store implementations, such as the 3 provided by PyTorch
distributed: ([`TCPStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.TCPStore "torch.distributed.TCPStore"), [`FileStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.FileStore "torch.distributed.FileStore"), and [`HashStore`](https://pytorch.org/docs/stable/data.html#torch.distributed.HashStore "torch.distributed.HashStore")).
__init__(self: torch._C._distributed_c10d.Store) → Noneadd(self: torch._C._distributed_c10d.Store, arg0: str , arg1: int ) → int首次对某个 key 调用 add 方法时,会在存储中创建一个与该 key 关联的计数器,并初始化为 amount 值。后续对相同 key 调用 add 方法时,计数器会按指定的 amount 值递增。
若调用 add() 时指定的 key 已被 set() 方法设置过,则会抛出异常。
参数
- key (str)– 存储中待递增计数器的键名
- amount ( int )– 计数器递增的数值量
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# 以TCPStore为例,其他存储类型也可使用
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.add("first_key", 1)
>>
> store.add("first_key", 6)
>>
>
# 应返回7
>>
> store.get("first_key")append(self: torch._C._distributed_c10d.Store, arg0: str , arg1: str ) → None根据提供的 key 和 value 将键值对追加到存储中。如果存储中不存在该 key,则会自动创建。
参数
- key (str)– 要追加到存储中的键名。
- value (str)– 与- key关联并添加到存储中的值。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.append("first_key", "po")
>>
> store.append("first_key", "tato")
>>
>
# Should return "potato"
>>
> store.get("first_key")check(self: torch._C._distributed_c10d.Store, arg0: list[str]) → bool检查给定keys列表是否在存储中有值的调用。该调用在正常情况下会立即返回,但仍可能遇到某些边缘死锁情况,例如在TCPStore已被销毁后调用检查。
调用check()时传入需要检查是否存在于存储中的键列表。
参数
- keys (list[str])– 需要查询是否存在于存储中的键列表。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# 以TCPStore为例,其他存储类型也可使用
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.add("first_key", 1)
>>
>
# 应返回7
>>
> store.check(["first_key"])compare_set(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  str , arg2:  str ) → bytes根据提供的 key 将键值对插入存储,并在插入前对 expected_value 和 desired_value 进行比较。只有当该 key 对应的 expected_value 已存在于存储中,或 expected_value 为空字符串时,才会设置 desired_value。
参数
- key (str)– 需要在存储中检查的键名。
- expected_value (str)– 插入前需检查的、与- key关联的预期值。
- desired_value (str)– 需要添加到存储中、与- key关联的目标值。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("key", "first_value")
>>
> store.compare_set("key", "first_value", "second_value")
>>
>
# 应返回 "second_value"
>>
> store.get("key")delete_key(self: torch._C._distributed_c10d.Store, arg0:  str ) → bool从存储中删除与key关联的键值对。如果键成功删除则返回true,否则返回false。
警告:delete_key API仅支持TCPStore和HashStore。在FileStore上使用此API会引发异常。
参数
- key (str)- 要从存储中删除的键
返回值:如果key被删除则返回True,否则返回False。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# Using TCPStore as an example, HashStore can also be used
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("first_key")
>>
>
# This should return true
>>
> store.delete_key("first_key")
>>
>
# This should return false
>>
> store.delete_key("bad_key")get(self: torch._C._distributed_c10d.Store, arg0: str ) → bytes从存储中获取与给定key关联的值。如果key不存在于存储中,该函数将等待初始化存储时定义的timeout时长,然后抛出异常。
参数
- key (str)– 函数将返回与此键关联的值。
返回值:如果key存在于存储中,则返回与之关联的值。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("first_key", "first_value")
>>
>
# Should return "first_value"
>>
> store.get("first_key")has_extended_api(self: torch._C._distributed_c10d.Store) → bool如果存储支持扩展操作,则返回 true。
multi_get(self: torch._C._distributed_c10d.Store, arg0: list [str ]) → list [bytes ]
获取 keys 中的所有值。如果 keys 中的任意键不存在于存储中,该函数将等待 timeout。
参数
- keys (List[str])– 要从存储中获取的键列表。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("first_key", "po")
>>
> store.set("second_key", "tato")
>>
>
# 应返回 [b"po", b"tato"]
>>
> store.multi_get(["first_key", "second_key"])multi_set(self: torch._C._distributed_c10d.Store, arg0: list [str ], arg1: list [str ]) → None根据提供的 keys 和 values 向存储中插入一个键值对列表
参数
- keys (List[str])– 要插入的键列表
- values (List[str])– 要插入的值列表
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.multi_set(["first_key", "second_key"], ["po", "tato"])
>>
>
# Should return b"po"
>>
> store.get("first_key")num_keys(self: torch._C._distributed_c10d.Store) → int返回存储中设置的键数量。需要注意的是,这个数字通常会比通过set()和add()方法添加的键数量多1,因为其中一个键用于协调所有使用该存储的工作进程。
警告:当与TCPStore一起使用时,num_keys返回的是写入底层文件的键数量。如果存储被销毁后,另一个存储使用同一文件创建,原有的键仍会被保留。
返回值:存储中当前存在的键数量。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# 以TCPStore为例,也可以使用其他存储类型
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("first_key", "first_value")
>>
>
# 这里应该返回2
>>
> store.num_keys()set(self: torch._C._distributed_c10d.Store, arg0:  str , arg1:  str ) → None根据提供的 key 和 value 将键值对插入存储中。如果 key 已存在于存储中,则会用新提供的 value 覆盖旧值。
参数
- key (str)– 要添加到存储中的键。
- value (str)– 与- key关联并要添加到存储中的值。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set("first_key", "first_value")
>>
>
# Should return "first_value"
>>
> store.get("first_key")set_timeout(self: torch._C._distributed_c10d.Store, arg0: [datetime.timedelta) → None设置存储的默认超时时间。该超时时间会在初始化期间以及在 wait() 和 get() 方法中使用。
参数
- timeout (timedelta)– 要设置到存储中的超时时间。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# 以TCPStore为例,也可以使用其他存储类型
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
> store.set_timeout(timedelta(seconds=10))
>>
>
# 10秒后将抛出异常
>>
> store.wait(["bad_key"])property timeout获取存储的超时设置。
wait(*args, **kwargs)
这是一个重载函数。
1、wait(self: torch._C._distributed_c10d.Store, arg0: list[str]) -None
等待keys列表中的每个键被添加到存储中。如果在timeout(存储初始化时设置)之前未设置所有键,则wait将抛出异常。
参数
- keys (list)– 需要等待的键列表,直到它们在存储中被设置。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# Using TCPStore as an example, other store types can also be used
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
>
# This will throw an exception after 30 seconds
>>
> store.wait(["bad_key"])2、wait(self: torch._C._distributed_c10d.Store, arg0: list[str], arg1: datetime.timedelta) -None
等待keys中的每个键被添加到存储中,如果在指定的timeout时间内这些键未被设置,则抛出异常。
参数说明
- keys (list)– 需要等待其被设置到存储中的键列表。
- timeout (timedelta)– 在抛出异常前等待键被添加的最长时间。
使用示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# Using TCPStore as an example, other store types can also be used
>>
> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>
>
# This will throw an exception after 10 seconds
>>
> store.wait(["bad_key"], timedelta(seconds=10))class torch
.distributed.TCPStore基于TCP协议的分布式键值存储实现。服务器端存储数据,客户端存储可以通过TCP连接到服务器存储,并执行诸如set()插入键值对、get()获取键值对等操作。必须始终初始化一个服务器存储,因为客户端存储会等待服务器建立连接。
参数
- host_name (str)– 服务器存储应运行的主机名或IP地址。
- port (int)– 服务器存储监听传入请求的端口号。
- world_size (int, 可选)– 存储用户总数(客户端数量 + 1个服务器)。默认为None(None表示存储用户数量不固定)。
- is_master ([bool], 可选)– 初始化服务器存储时为True,客户端存储时为False。默认为False。
- timeout (timedelta, 可选)– 存储初始化及- get()、- wait()等方法使用的超时时间。默认为timedelta(seconds=300)。
- wait_for_workers ([bool], 可选)– 是否等待所有工作节点与服务器存储建立连接。仅当world_size为固定值时适用。默认为True。
- multi_tenant ([bool], 可选)– 若为True,当前进程中具有相同host/port的所有- TCPStore实例将共享同一个底层- TCPServer。默认为False。
- master_listen_fd (int, 可选)– 若指定,底层- TCPServer将监听此文件描述符(必须为已绑定到- port的套接字)。适用于避免某些场景下的端口分配竞争。默认为None(表示服务器创建新套接字并尝试绑定到- port)。
- use_libuv ([bool], 可选)– 若为True,使用libuv作为- TCPServer后端。默认为True。
示例:
>>
>
import torch.distributed as dist
>>
>
from datetime import timedelta
>>
>
# 在进程1(服务端)运行
>>
> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
>>
>
# 在进程2(客户端)运行
>>
> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
>>
>
# 初始化后,客户端或服务端均可使用存储方法
>>
> server_store.set("first_key", "first_value")
>>
> client_store.get("first_key")__init__(self: [torch._C._distributed_c10d.TCPStore](https://pytorch.org/docs/stable/data.html#torch.distributed.TCPStore "torch._C._distributed_c10d.TCPStore"), host_name:  str , port:  int , world_size: Optional[int ] = None, is_master:  bool  = False, timeout: [datetime.timedelta = datetime.timedelta(seconds=300), wait_for_workers:  bool  = True, multi_tenant:  bool  = False, master_listen_fd: Optional[int ] = None, use_libuv:  bool  = True) → None创建一个新的 TCPStore。
property host获取存储服务监听请求的主机名。
property libuvBackend返回 True 表示当前正在使用 libuv 后端。
property port获取存储服务监听请求的端口号。
class torch
.distributed.HashStore一个基于底层哈希映射的线程安全存储实现。该存储可以在同一进程内使用(例如被其他线程使用),但不能跨进程使用。
示例:
>>
>
import torch.distributed as dist
>>
> store = dist.HashStore()
>>
>
# store can be used from other threads
>>
>
# Use any of the store methods after initialization
>>
> store.set("first_key", "first_value")__init__(self: [torch._C._distributed_c10d.HashStore](https://pytorch.org/docs/stable/data.html#torch.distributed.HashStore "torch._C._distributed_c10d.HashStore")) → None创建一个新的 HashStore。
class torch
.distributed.FileStore一个使用文件存储底层键值对的存储实现。
参数
- file_name (str)– 用于存储键值对的文件路径
- world_size ( int , 可选)– 使用该存储的进程总数。默认为-1(负值表示存储用户数量不固定)。
示例:
>>
>
import torch.distributed as dist
>>
> store1 = dist.FileStore("/tmp/filestore", 2)
>>
> store2 = dist.FileStore("/tmp/filestore", 2)
>>
>
# Use any of the store methods from either the client or server after initialization
>>
> store1.set("first_key", "first_value")
>>
> store2.get("first_key")__init__(self: torch._C._distributed_c10d.FileStore, file_name: str, world_size: int = -1) → None创建一个新的 FileStore。
property path获取FileStore用于存储键值对的文件路径。
class torch
.distributed.PrefixStore对三种键值存储(TCPStore、FileStore 和 HashStore)的封装器,会在每个存入存储的键前添加前缀。
参数
- prefix (str)- 在键存入存储前添加的前缀字符串。
- store (torch.distributed.store)- 作为底层键值存储的存储对象。
__init__(self: torch._C._distributed_c10d.PrefixStore, prefix: str , store: torch._C._distributed_c10d.Store) → None创建一个新的 PrefixStore。
property underlying_store获取 PrefixStore 所封装的基础存储对象。
分析集体通信性能
请注意,您可以使用 torch.profiler(推荐使用,仅1.8.1版本后可用)或 torch.autograd.profiler 来分析本文提到的集体通信和点对点通信API。所有开箱即用的后端(gloo、nccl、mpi)都支持性能分析,集体通信的使用情况将在分析输出/跟踪中按预期呈现。分析代码的方式与常规的torch运算符完全相同:
import torch
import torch.distributed as dist
with torch.profiler():
tensor = torch.randn(20, 10)
dist.all_reduce(tensor)请参阅 性能分析器文档 以获取性能分析器功能的完整概述。
多GPU集合函数
警告:多GPU函数(指每个CPU线程对应多个GPU)已被弃用。目前,PyTorch分布式推荐采用每个线程对应一个设备的编程模型,本文档中的API即体现了这一模式。如果您是后端开发者且需要支持每个线程管理多个设备,请联系PyTorch分布式维护团队。
第三方后端
除了内置的 GLOO/MPI/NCCL 后端外,PyTorch 分布式模块通过运行时注册机制支持第三方后端。关于如何通过 C++ 扩展开发第三方后端的参考文档,请查阅 教程 - 自定义 C++ 和 CUDA 扩展 以及 test/cpp_extensions/cpp_c10d_extension.cpp。第三方后端的功能由其自身实现决定。
新后端需要继承自 c10d::ProcessGroup,并在导入时通过 torch.distributed.Backend.register_backend() 注册后端名称和实例化接口。
当手动导入该后端并通过指定后端名称调用 torch.distributed.init_process_group() 时,torch.distributed 包将运行在新的后端上。
警告:第三方后端支持目前处于实验阶段,后续可能发生变更。
启动工具
torch.distributed 包还在 torch.distributed.launch 中提供了一个启动工具。这个辅助工具可用于在每个节点上启动多个进程进行分布式训练。
模块 torch.distributed.launch。
torch.distributed.launch 是一个模块,可在每个训练节点上生成多个分布式训练进程。
警告:该模块将被 torchrun 取代。
该工具可用于单节点分布式训练,其中每个节点会生成一个或多个进程。该工具既可用于 CPU 训练,也可用于 GPU 训练。如果用于 GPU 训练,每个分布式进程将在单个 GPU 上运行。这可以显著提升单节点训练性能。它也可用于多节点分布式训练,通过在每个节点上生成多个进程,同样显著提升多节点分布式训练性能。这对于具有多个支持直接 GPU 的 Infiniband 接口的系统尤其有益,因为所有这些接口都可以用于聚合通信带宽。
无论是单节点分布式训练还是多节点分布式训练,该工具都会在每个节点上启动指定数量的进程(--nproc-per-node)。如果用于 GPU 训练,这个数字需要小于或等于当前系统上的 GPU 数量(nproc_per_node),并且每个进程将在 GPU 0 到 GPU (nproc_per_node - 1) 上运行。
如何使用该模块:
1、单节点多进程分布式训练
python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
arguments of your training script)2、多节点多进程分布式训练:(例如两个节点)
节点1:(IP: 192.168.1.1,空闲端口:1234)
python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
--nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
--master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
and all other arguments of your training script)Node 2:
python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
--nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
--master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
and all other arguments of your training script)3、要查看该模块提供的可选参数:
python -m torch.distributed.launch --help重要注意事项:
1、当前该工具及多进程分布式(单节点或多节点)GPU训练仅在NCCL分布式后端下才能实现最佳性能。因此,推荐在GPU训练中使用NCCL后端。
2、在训练程序中,必须解析命令行参数:
--local-rank=LOCAL_PROCESS_RANK(该参数将由本模块提供)。
若训练程序使用GPU,需确保代码仅在LOCAL_PROCESS_RANK对应的GPU设备上运行。可通过以下方式实现:
解析local_rank参数
>>
>
import argparse
>>
> parser = argparse.ArgumentParser()
>>
> parser.add_argument("--local-rank", "--local_rank", type=int)
>>
> args = parser.parse_args()将您的设备设置为本地等级,可通过以下方式实现:
>>
> torch.cuda.set_device(args.local_rank) # 在代码运行前执行此操作or
>>
>
with torch.cuda.device(args.local_rank):
>>
>
# 在此处运行你的代码
>>
>
...版本 2.0.0 变更:启动器会向您的脚本传递 --local-rank=<rank> 参数。
从 PyTorch 2.0.0 开始,推荐使用带连字符的 --local-rank 而非之前使用的带下划线形式 --local_rank。
为了保持向后兼容性,用户可能需要在参数解析代码中同时处理这两种情况。这意味着在参数解析器中需要同时包含 "--local-rank" 和 "--local_rank"。如果仅提供 "--local_rank",启动器会报错:“error: unrecognized arguments: –local-rank=”。对于仅支持 PyTorch 2.0.0+ 的训练代码,包含 "--local-rank" 应该就足够了。
3、在您的训练程序中,应当在开始时调用以下函数来启动分布式后端。强烈建议使用 init_method=env://。其他初始化方法(如 tcp://)可能有效,但 env:// 是本模块官方支持的方式。
>>
> torch.distributed.init_process_group(backend='YOUR BACKEND',
init_method='env://')在训练程序中,您可以选择使用常规的分布式函数,也可以使用 torch.nn.parallel.DistributedDataParallel() 模块。如果您的训练程序使用 GPU 进行训练,并且希望使用 torch.nn.parallel.DistributedDataParallel() 模块,以下是配置方法。
>>
> model = torch.nn.parallel.DistributedDataParallel(model, >> device_ids=[args.local_rank], >> output_device=args.local_rank)请确保将 device_ids 参数设置为代码将操作的唯一 GPU 设备 ID。这通常是进程的本地排名(local rank)。换句话说,要使用此工具,device_ids 需设为 [args.local_rank],且 output_device 需设为 args.local_rank。
5、另一种通过环境变量 LOCAL_RANK 向子进程传递 local_rank 的方法:当使用 --use-env=True 启动脚本时,此功能会自动启用。你必须修改上述子进程示例,将 args.local_rank 替换为 os.environ['LOCAL_RANK'];若指定该标志,启动器将不会传递 --local-rank 参数。
警告:local_rank 并非全局唯一,它仅在单台机器的进程内唯一。因此,切勿用它来决定是否执行诸如写入网络文件系统等操作。若未正确处理,可能导致问题,具体案例可参考 https://github.com/pytorch/pytorch/issues/12042。
生成进程工具
多进程包 - torch.multiprocessing 提供了 torch.multiprocessing.spawn() 中的 spawn 函数。这个辅助函数可用于生成多个进程,其工作原理是传入目标执行函数,然后创建N个进程来运行该函数。该工具也可用于多进程分布式训练。
具体用法示例可参考 PyTorch示例 - ImageNet实现
注意:此功能需要Python 3.4或更高版本。
调试 torch.distributed 应用程序
由于难以理解的挂起、崩溃或跨进程的不一致行为,调试分布式应用程序可能具有挑战性。torch.distributed 提供了一套工具,以自助方式帮助调试训练应用程序:
Python 断点调试
在分布式环境中使用Python调试器极为便利,但由于开箱即用功能不足,许多人完全未使用它。PyTorch提供了一个定制化的pdb封装器,可简化这一流程。
torch.distributed.breakpoint 使该过程变得简单。其内部通过两种方式定制pdb的断点行为,其余功能与常规pdb一致:
 1、仅在被用户指定的特定rank上附加调试器
 2、通过调用torch.distributed.barrier()确保其他所有rank暂停运行,该屏障会在被调试rank发出继续指令后解除
 3、将子进程的标准输入重定向至您的终端
使用时,只需在所有rank上调用torch.distributed.breakpoint(rank),并确保各rank传入相同的rank值即可。
监控式屏障
从 v1.10 版本开始,torch.distributed.monitored_barrier() 作为 torch.distributed.barrier() 的替代方案存在。当发生崩溃时(即并非所有 rank 在指定超时时间内调用 torch.distributed.monitored_barrier()),该函数会提供有关可能故障 rank 的有用信息。torch.distributed.monitored_barrier() 通过类似确认机制的 send/recv 通信原语在主机端实现屏障功能,使得 rank 0 能够报告哪些 rank 未能及时确认屏障。
例如,考虑以下场景:rank 1 未能调用 torch.distributed.monitored_barrier()(实际中可能由于应用程序错误或前一个集合操作挂起导致):
import os
from datetime import timedelta
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def worker(rank):
dist.init_process_group("nccl", rank=rank, world_size=2)监控屏障需要 gloo 进程组执行主机端同步
group_gloo = dist.new_group(backend="gloo")
if rank not in [1]:
dist.monitored_barrier(group=group_gloo, timeout=timedelta(seconds=2))
if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
mp.spawn(worker, nprocs=2, args=())在 rank 0 上会产生以下错误信息,使用户能够判断哪些 rank 可能出现故障并进行进一步排查:
RuntimeError: Rank 1 failed to pass monitoredBarrier in 2000 ms
Original exception:
[gloo/transport/tcp/pair.cc:598] Connection closed by peer [2401:db00:eef0:1100:3560:0:1c05:25d]:8594说明:
1、保留了代码块格式和所有技术术语(如RuntimeError、Rank、monitoredBarrier)
2、将被动语态"Connection closed by peer"转换为主动语态"对等方关闭了连接"
3、保持了IP地址和端口号的原始格式
4、错误信息路径[gloo/transport/tcp/pair.cc:598]保持原样
5、时间单位"ms"转换为中文习惯的"毫秒"
TORCH_DISTRIBUTED_DEBUG
当设置 TORCH_CPP_LOG_LEVEL=INFO 时,环境变量 TORCH_DISTRIBUTED_DEBUG 可用于触发额外的有用日志记录和集体同步检查,以确保所有进程能正确同步。根据所需的调试级别,TORCH_DISTRIBUTED_DEBUG 可设置为 OFF(默认)、INFO 或 DETAIL。请注意,最详细的选项 DETAIL 可能会影响应用程序性能,因此应仅在调试问题时使用。
设置 TORCH_DISTRIBUTED_DEBUG=INFO 会在初始化使用 torch.nn.parallel.DistributedDataParallel() 训练的模型时生成额外的调试日志;而设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 还会在选定的迭代次数中记录运行时性能统计信息。这些运行时统计信息包括前向传播时间、反向传播时间、梯度通信时间等数据。例如,给定以下应用程序:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
class TwoLinLayerNet
(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.nn.Linear(10, 10, bias=False)
self.b = torch.nn.Linear(10, 1, bias=False)
def forward(self, x):
a = self.a(x)
b = self.b(x)
return (a, b)
def worker(rank):
dist.init_process_group("nccl", rank=rank, world_size=2)
torch.cuda.set_device(rank)
print("init model")
model = TwoLinLayerNet().cuda()
print("init ddp")
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
inp = torch.randn(10, 10).cuda()
print("train")
for _ in range(20):
output = ddp_model(inp)
loss = output[0] + output[1]
loss.sum().backward()
if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
os.environ[
"TORCH_DISTRIBUTED_DEBUG"
] = "DETAIL" # set to DETAIL for runtime logging.
mp.spawn(worker, nprocs=2, args=())初始化时会渲染以下日志:
I0607 16:10:35.739390 515217 logger.cpp:173] [Rank 0]: DDP Initialized with:
broadcast_buffers: 1
bucket_cap_bytes: 26214400
find_unused_parameters: 0
gradient_as_bucket_view: 0
is_multi_device_module: 0
iteration: 0
num_parameter_tensors: 2
output_device: 0
rank: 0
total_parameter_size_bytes: 440
world_size: 2
backend_name: nccl
bucket_sizes: 440
cuda_visible_devices: N/A
device_ids: 0
dtypes: float
master_addr: localhost
master_port: 29501
module_name: TwoLinLayerNet
nccl_async_error_handling: N/A
nccl_blocking_wait: N/A
nccl_debug: WARN
nccl_ib_timeout: N/A
nccl_nthreads: N/A
nccl_socket_ifname: N/A
torch_distributed_debug: INFO运行时(当设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 时)会显示以下日志:
I0607 16:18:58.085681 544067 logger.cpp:344] [Rank 1 / 2] Training TwoLinLayerNet unused_parameter_size=0
Avg forward compute time: 40838608
Avg backward compute time: 5983335
Avg backward comm. time: 4326421
Avg backward comm/comp overlap time: 4207652
I0607 16:18:58.085693 544066 logger.cpp:344] [Rank 0 / 2] Training TwoLinLayerNet unused_parameter_size=0
Avg forward compute time: 42850427
Avg backward compute time: 3885553
Avg backward comm. time: 2357981
Avg backward comm/comp overlap time: 2234674此外,TORCH_DISTRIBUTED_DEBUG=INFO 增强了 torch.nn.parallel.DistributedDataParallel() 中因模型存在未使用参数导致的崩溃日志记录。当前,如果前向传播中存在可能未被使用的参数,必须在初始化 torch.nn.parallel.DistributedDataParallel() 时传入 find_unused_parameters=True。从 v1.10 开始,由于 torch.nn.parallel.DistributedDataParallel() 不支持反向传播中存在未使用参数,所有模型输出都必须参与损失计算。这些限制对大型模型尤其具有挑战性。
因此当发生错误崩溃时,torch.nn.parallel.DistributedDataParallel() 会记录所有未被使用参数的完全限定名称。例如在上述应用中,如果将损失计算改为 loss = output[1],那么 TwoLinLayerNet.a 在反向传播中不会接收梯度,从而导致 DDP 失败。崩溃时,系统会向用户提供关于未使用参数的信息——这对于大型模型而言可能难以手动定位。
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing
the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return va
lue of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameters which did not receive grad for rank 0: a.weight
Parameter indices which did not receive grad for rank 0: 0设置 TORCH_DISTRIBUTED_DEBUG=DETAIL 会触发对用户发起的每个集体调用(无论是直接调用还是间接调用,例如 DDP 的 allreduce)进行额外的同步性和一致性检查。具体实现方式是创建一个包装器进程组,该包装器会包裹所有通过 torch.distributed.init_process_group() 和 torch.distributed.new_group() API 返回的进程组。因此,这些 API 将返回一个包装器进程组,其使用方式与常规进程组完全相同,但在将集体操作分发给底层进程组之前会执行一致性检查。
目前,这些检查包括调用 torch.distributed.monitored_barrier(),该操作会确保所有节点完成未完成的集体调用,并报告卡住的节点。接着,系统会通过验证所有集体函数是否匹配且使用一致的张量形状来检查集体操作本身的一致性。如果不符合条件,应用程序崩溃时会提供包含详细错误信息的报告,而不是直接挂起或返回无意义的错误消息。例如,考虑以下函数中传入 torch.distributed.all_reduce() 的张量形状不匹配的情况:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def worker(rank):
dist.init_process_group("nccl", rank=rank, world_size=2)
torch.cuda.set_device(rank)
tensor = torch.randn(10 if rank == 0 else 20).cuda()
dist.all_reduce(tensor)
torch.cuda.synchronize(device=rank)
if __name__ == "__main__":
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
mp.spawn(worker, nprocs=2, args=())使用NCCL后端时,这类应用很可能会导致程序挂起,在复杂场景下难以定位根本原因。如果用户启用
TORCH_DISTRIBUTED_DEBUG=DETAIL并重新运行应用,以下错误信息会揭示根本原因:
work = default_pg.allreduce([tensor], opts)
RuntimeError: Error when verifying shape tensors for collective ALLREDUCE on rank 0、This likely indicates that input shapes into the collective are mismatched across ranks. Got shapes: 10
[ torch.LongTensor{
1
} ]注意:如需在运行时对调试级别进行细粒度控制,还可以使用以下函数:torch.distributed.set_debug_level()、torch.distributed.set_debug_level_from_env() 和 torch.distributed.get_debug_level()。
此外,可以将 TORCH_DISTRIBUTED_DEBUG=DETAIL 与 TORCH_SHOW_CPP_STACKTRACES=1 结合使用,以便在检测到集合操作不同步时记录完整的调用堆栈。这些集合操作不同步检查适用于所有使用 c10d 集合调用的应用程序,这些调用由通过 torch.distributed.init_process_group() 和 torch.distributed.new_group() API 创建的进程组支持。
日志记录
除了通过 torch.distributed.monitored_barrier() 和 TORCH_DISTRIBUTED_DEBUG 提供的显式调试支持外,torch.distributed 的底层 C++ 库还会输出不同级别的日志消息。这些消息有助于理解分布式训练作业的执行状态,并排查诸如网络连接故障等问题。下表展示了如何通过组合 TORCH_CPP_LOG_LEVEL 和 TORCH_DISTRIBUTED_DEBUG 环境变量来调整日志级别。
| TORCH_CPP_LOG_LEVEL | TORCH_DISTRIBUTED_DEBUG | 实际日志级别 | 
|---|---|---|
| ERROR | 忽略 | 错误 | 
| WARNING | 忽略 | 警告 | 
| INFO | 忽略 | 信息 | 
| INFO | INFO | 调试 | 
| INFO | DETAIL | 跟踪(即全部) | 
分布式组件会抛出从 RuntimeError 派生的自定义异常类型:
- torch.distributed.DistError:这是所有分布式异常的基类型。
- torch.distributed.DistBackendError:当发生后端特定错误时抛出此异常。例如,如果使用 NCCL 后端且用户尝试使用 NCCL 库不可用的 GPU。
- torch.distributed.DistNetworkError:当网络库遇到错误时抛出此异常(例如:连接被对端重置)。
- torch.distributed.DistStoreError:当 Store 遇到错误时抛出此异常(例如:TCPStore 超时)。
class torch
.distributed.DistError分布式库中发生错误时引发的异常
class torch
.distributed.DistBackendError当分布式系统中发生后端错误时引发的异常
class torch
.distributed.DistNetworkError分布式系统中发生网络错误时引发的异常
class torch
.distributed.DistStoreError分布式存储发生错误时引发的异常
如果正在运行单节点训练,可以方便地以交互方式在脚本中设置断点。我们提供了一种便捷的方法来为单个 rank 设置断点:
torch.distributed.breakpoint(rank=0, skip=0)功能说明
设置断点,但仅对单个指定rank生效。其他所有rank会等待该断点执行完成后才继续运行。
参数说明
- rank (int)– 指定触发断点的rank编号,默认为- 0
- skip (int)– 跳过前- skip次对该断点的调用,默认为- 0
torch.distributed.tensor
注意:torch.distributed.tensor 目前处于 alpha 开发阶段,文档中列出的大部分 API 我们将确保向后兼容性,但必要时可能会进行 API 变更。
PyTorch DTensor(分布式张量)
PyTorch DTensor 提供简单灵活的张量分片原语,能够透明处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集合通信。DTensor 可用于构建不同的并行解决方案,并支持在多维分片场景下表示分片状态的 state_dict。
以下是基于 DTensor 构建的 PyTorch 原生并行方案示例:
DTensor 遵循 SPMD(单程序多数据)编程模型,让用户能够像编写具有相同收敛特性的单设备程序那样编写分布式程序。它通过指定 DeviceMesh 和 Placement 提供统一的张量分片布局(DTensor 布局):
- DeviceMesh使用 n 维数组表示集群的设备拓扑和通信器
- Placement描述逻辑张量在- DeviceMesh上的分片布局
 DTensor 支持三种分片类型:- Shard(分片)、- Replicate(复制)和- Partial(部分)。
DTensor 类 API
DTensor 是 torch.Tensor 的子类。这意味着一旦创建了 DTensor,就可以以与 torch.Tensor 非常相似的方式使用它,包括运行不同类型的 PyTorch 操作符,就像在单个设备上运行它们一样,同时为 PyTorch 操作符提供正确的分布式计算支持。
除了现有的 torch.Tensor 方法外,它还提供了一组额外的方法来与 torch.Tensor 交互、将 DTensor 布局重新分配到新的 DTensor、获取所有设备上的完整张量内容等。
class torch
.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)DTensor(分布式张量)是 torch.Tensor 的子类,它为多设备 torch.Tensor 提供了类似单设备的编程抽象。它通过 DeviceMesh 和以下类型的 Placement 来描述分布式张量的分片布局(DTensor Layout):
- Shard:张量在- DeviceMesh维度的设备上沿张量维度- dim分片
- Replicate:张量在- DeviceMesh维度的设备上完整复制
- Partial:张量在- DeviceMesh维度的设备上待规约
当调用 PyTorch 算子时,DTensor 会重载这些算子以执行分片计算,并在必要时发起通信。在算子计算过程中,DTensor 会根据算子本身的语义正确转换或传播布局(DTensor Layout),并生成新的 DTensor 输出。
为确保调用 PyTorch 算子时 DTensor 分片计算的数值正确性,DTensor 要求算子的每个 Tensor 参数都必须是 DTensor。
注意:直接使用 Tensor 子类构造函数创建 DTensor 并非推荐方式(例如它无法正确处理自动求导,因此不属于公开 API)。请参阅 create_dtensor 章节了解如何正确创建 DTensor。
返回类型:DTensor
__create_chunk_list__()返回一个 ChunkStorageMetadata 列表,该数据类用于描述当前 rank 上本地分片/副本的大小和偏移量。对于 DTensor,每个 rank 只会有一个本地分片/副本,因此返回的列表通常仅包含一个元素。
此双下划线方法主要用于分布式检查点用途。
返回值:一个 List[ChunkStorageMetadata] 对象,表示当前 rank 上的分片大小/偏移量。
static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)根据指定的 device_mesh 和 placements,从各 rank 上的本地 torch.Tensor 创建一个 DTensor
参数
- local_tensor (torch.Tensor)– 各 rank 上的本地 torch.Tensor。
- device_mesh (DeviceMesh- , 可选)– 用于放置张量的 DeviceMesh。若未指定,则必须在 DeviceMesh 上下文管理器中调用,默认值:None
- placements (List[Placement- ], 可选)– 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的布局列表,其元素数量必须与- device_mesh.ndim相同。
关键字参数
- run_check ([bool], 可选)– 以额外通信为代价,跨 rank 执行完整性检查,验证各本地张量的元信息以确保正确性。若- placements中包含- Replicate,设备网格维度的第一个 rank 上的数据将被广播到其他 rank。默认值:False
- shape ( torch.Size , 可选)– 指定构建在 local_tensor 之上的 DTensor 大小的整型列表。注意:当各 rank 上- local_tensor的形状不同时必须提供此参数。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算- shape。默认值:None
- stride ( tuple , 可选)– 指定 DTensor 步长的整型列表。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算- stride。默认值:None
返回
一个 DTensor 对象
返回类型:DTensor
注意:当 run_check=False 时,用户需自行确保传入的本地张量在各 rank 间正确(即对于 Shard(dim) 布局张量需分片,对于 Replicate() 布局需复制)。否则,所创建 DTensor 的行为将是未定义的。
注意:from_local 是可微操作,创建的 DTensor 对象的 requires_grad 属性将取决于 local_tensor 是否 requires_grad。
full_tensor(*, grad_placements=None)返回该DTensor的完整张量。该方法会执行必要的集合通信操作,从所在DeviceMesh的其他rank上收集本地张量并进行拼接。这是以下代码的语法糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
关键字参数
- grad_placements (List[Placement- ], 可选)– 该参数描述了从本函数返回的完整张量对应的梯度布局的未来分布方式。
full_tensor将DTensor转换为完整的torch.Tensor,但返回的torch.tensor在后续代码中可能不会保持原始复制的DTensor布局。这个参数是用户提供给autograd的提示,用于处理返回张量的梯度布局与原始复制的DTensor布局不匹配的情况。如果未指定,我们将假定完整张量的梯度布局为复制式分布。
返回值:一个表示该DTensor完整张量的torch.Tensor对象。
返回类型: Tensor
注意:full_tensor是可微分的。
redistribute(device_mesh=None, placements=None, *, async_op=False)redistribute 执行必要的集体操作,将当前 DTensor 从其现有布局重新分配到新布局,或从当前 DeviceMesh 迁移到新 DeviceMesh。例如,我们可以通过为 DeviceMesh 的每个维度指定 Replicate 布局,将分片(Sharded)DTensor 转换为复制(Replicated)DTensor。
当在 DeviceMesh 的某个维度上从当前布局重新分配到新布局时,将执行以下包含通信集体操作或本地操作:
1、Shard(dim) → Replicate():all_gather
2、Shard(src_dim) → Shard(dst_dim):all_to_all
3、Replicate() → Shard(dim):本地分块(即 torch.chunk)
4、Partial() → Replicate():all_reduce
5、Partial() → Shard(dim):reduce_scatter
redistribute 能够正确推断出针对在 1-D 或 N-D DeviceMesh 上创建的 DTensor 所需的重新分配步骤。
参数
- device_mesh (DeviceMesh- , 可选)– 用于放置 DTensor 的 DeviceMesh。若未指定,则使用当前 DTensor 的 DeviceMesh。
默认值:None
- placements (List[Placement- ], 可选)– 描述如何将 DTensor 放置到 DeviceMesh 中的新布局,其元素数量必须与- device_mesh.ndim相同。
默认值:在所有网格维度上复制(replicate)
关键字参数
- async_op ([bool], 可选)– 是否以异步方式执行 DTensor 重新分配操作。默认值:False
返回
一个 DTensor 对象
返回类型
注意:redistribute 是可微分的,这意味着用户无需担心重新分配操作的反向传播公式。
注意:redistribute 当前仅支持在同一 DeviceMesh 上重新分配 DTensor。若需将 DTensor 重新分配到不同 DeviceMesh,请提交问题。
to_local(*, grad_placements=None)获取当前 rank 上该 DTensor 的本地张量。对于分片情况,返回逻辑张量视图的本地分片;对于复制情况,返回当前 rank 上的副本。
关键字参数
- grad_placements (List[Placement- ], 可选)– 该参数描述从本函数返回张量的梯度未来布局。
to_local 将 DTensor 转换为本地张量,且返回的本地张量后续可能不会沿用原 DTensor 的布局。此参数是用户提供给自动求导的提示,用于处理返回张量的梯度布局与原 DTensor 不匹配的情况。若未指定,则默认梯度布局与原 DTensor 相同并用于梯度计算。
返回值:一个 torch.Tensor 或 AsyncCollectiveTensor 对象,表示当前 rank 上的本地张量。当返回 AsyncCollectiveTensor 对象时,意味着本地张量尚未就绪(即通信未完成)。此时用户需调用 wait 方法等待本地张量准备就绪。
返回类型 : Tensor
注意:to_local 是可微分的,返回本地张量的 requires_grad 属性将取决于原 DTensor 是否要求梯度。
property device_mesh: [DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")与该 DTensor 对象关联的 DeviceMesh 属性。
注意:device_mesh 是一个只读属性,不可被设置。
property placements: tuple [[torch.distributed.tensor.placement_types.Placement](https://pytorch.org/docs/stable/data.html#torch.distributed.tensor.placement_types.Placement "torch.distributed.tensor.placement_types.Placement"),
...]该 DTensor 的 placements 属性描述了其在设备网格(DeviceMesh)上的分布布局。
注意:placements 是只读属性,不可被修改。
作为分布式通信器的DeviceMesh
DeviceMesh基于DTensor构建,用于抽象描述集群设备拓扑结构,并作为多维通信器(基于ProcessGroup)的载体。如需了解如何创建/使用DeviceMesh的具体细节,请参阅DeviceMesh使用指南。
DTensor 布局类型
DTensor 支持在每个 DeviceMesh 维度上使用以下 Placement 类型:
class torch
.distributed.tensor.placement_types.Shard(dim)Shard(dim)布局描述了张量在维度dim上跨对应DeviceMesh维度的分片方式,其中DeviceMesh维度上的每个rank仅持有全局张量的一个分片。Shard(dim)布局遵循torch.chunk(dim)语义——当张量维度无法在DeviceMesh维度上均匀划分时,DeviceMesh维度上的最后几个分片可能为空。所有DTensor API(如distribute_tensor、from_local等)均可使用Shard布局。
参数
- dim (int)- 指定张量在对应DeviceMesh维度上进行分片的维度编号。
警告:当前对无法在DeviceMesh维度上均匀划分的张量维度进行分片属于实验性功能,后续可能变更。
dim: int
class torch
.distributed.tensor.placement_types.ReplicateReplicate()布局描述了DTensor在对应的DeviceMesh维度上进行复制的行为,其中DeviceMesh维度上的每个rank都持有全局Tensor的一个副本。所有DTensor API(例如distribute_tensor、DTensor.from_local等)都可以使用Replicate布局。
class torch
.distributed.tensor.placement_types.Partial(reduce_op='sum')Partial(reduce_op)布局描述了在指定DeviceMesh维度上待归约的DTensor,其中DeviceMesh维度的每个rank持有全局Tensor的部分值。用户可以通过redistribute将Partial DTensor转换为指定DeviceMesh维度上的Replicate或Shard(dim)布局,这将触发底层的必要通信操作(如allreduce、reduce_scatter)。
参数
- reduce_op (str, 可选)– 用于将Partial DTensor转换为Replicated/Sharded DTensor的归约操作。仅支持逐元素的归约操作,包括:“sum”、“avg”、“product”、“max”、“min”,默认值为"sum"。
注意:Partial布局可能作为DTensor运算符的结果生成,且只能通过DTensor.from_local API使用。
reduce_op: str = 'sum'class torch
.distributed.tensor.placement_types.PlacementPlacement 类型的基类,用于描述如何将 DTensor 放置在 DeviceMesh 上。Placement 和 DeviceMesh 共同定义了 DTensor 的布局。
它是三种主要 DTensor 放置类型(Shard、Replicate 和 Partial)的基类。
这个类不直接使用,主要作为类型标注存根。
is_partial()返回类型:bool
is_replicate()返回类型:bool
is_shard(dim=None)返回类型:bool
创建 DTensor 的不同方式
有三种方法可以构建 DTensor:
- distribute_tensor()从每个 rank 上的逻辑或"全局"- torch.Tensor创建- DTensor。这可用于对叶子节点- torch.Tensor(即模型参数/缓冲区和输入)进行分片。
- DTensor.from_local()从每个 rank 上的本地- torch.Tensor创建- DTensor,可用于从非叶子节点- torch.Tensor(即前向/反向传播过程中的中间激活张量)创建- DTensor。
- DTensor 提供了专门的张量工厂函数(如 empty()、ones()、randn()等),通过直接指定DeviceMesh和Placement来创建不同的DTensor。与distribute_tensor()相比,这种方法可以直接在设备上实现分片内存,而不是在初始化逻辑张量内存后再执行分片操作。
从逻辑上的 torch.Tensor 创建 DTensor
torch.distributed 中的 SPMD(单程序多数据)编程模型会启动多个进程(例如通过 torchrun)来执行同一程序。这意味着程序内部的模型会先在不同进程上初始化(例如模型可能在 CPU、元设备上初始化,或者如果有足够内存则直接在 GPU 上初始化)。
DTensor 提供了一个 distribute_tensor() API,可以将模型权重或张量分片为多个 DTensor。该 API 会在每个进程上从“逻辑”张量创建 DTensor,从而使生成的 DTensor 遵循单一设备语义,这对于数值正确性至关重要。
torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)根据指定的placements将叶子节点torch.Tensor(如nn.Parameter/缓冲区)分发到device_mesh。device_mesh和placements的维度必须相同。待分发的tensor是逻辑或"全局"张量,该API会使用DeviceMesh第一个维度的首秩张量作为数据源以保持单设备语义。若需在自动梯度计算过程中构建DTensor,请改用DTensor.from_local()。
参数说明
- tensor (torch.Tensor)– 待分发的张量。注意:若需在设备网格维度上对无法整除的张量进行分片,将使用- torch.chunk语义进行分片和散射。非均匀分片行为尚处实验阶段,后续可能变更。
- device_mesh (DeviceMesh- , 可选)– 目标设备网格。若未指定,必须在DeviceMesh上下文管理器中调用,默认值:None
- placements (List[Placement- ], 可选)– 描述张量在设备网格上分布方式的定位策略,元素数量必须与- device_mesh.ndim相同。若未指定,默认会沿设备网格各维度的首秩复制张量。
关键字参数
- src_data_rank ( int , 可选)– 逻辑/全局张量的源数据秩,- distribute_tensor()通过此参数将分片/副本散射/广播到其他秩。默认使用各DeviceMesh维度上- group_rank=0作为数据源以保持单设备语义。若显式传入- None,该API将直接使用本地数据而非通过散射/广播保持单设备语义。默认值:0
返回值
 返回DTensor或XLAShardedTensor对象。
返回类型
DTensor
注意:当使用xla设备类型初始化DeviceMesh时,distribute_tensor会返回XLAShardedTensor。详见此问题。XLA集成功能尚处实验阶段,后续可能变更。
除distribute_tensor()外,DTensor还提供distribute_module()API,可在nn.Module层级实现更便捷的分片操作。
torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)该函数提供了三个功能来控制模块的参数/输入/输出:
1、通过在运行时执行前指定 partition_fn 对模块进行分片处理(即允许用户根据指定的 partition_fn 将模块参数转换为 DTensor 参数)。
2、通过在运行时执行时指定 input_fn 和 output_fn 来控制模块的输入或输出(即将输入转换为 DTensor,将输出转换回 torch.Tensor)。
参数
- module (nn.Module- )– 需要分片的用户模块。
- device_mesh (DeviceMesh- )– 用于放置模块的设备网格。
- partition_fn (Callable)– 用于分片参数的函数(即在- device_mesh上切分特定参数)。如果未指定- partition_fn,默认会在网格上复制- module的所有模块参数。
- input_fn (Callable)– 指定输入分布,即可以控制模块输入的切分方式。- input_fn会作为模块的- forward_pre_hook(前向钩子)安装。
- output_fn (Callable)– 指定输出分布,即可以控制输出的切分方式,或将其转换回 torch.Tensor。- output_fn会作为模块的- forward_hook(后向钩子)安装。
返回
一个包含所有参数/缓冲区的模块,这些参数/缓冲区均为 DTensor 类型。
返回类型:Module
注意:当使用 xla 设备类型初始化 DeviceMesh 时,distribute_module 会返回带有 PyTorch/XLA SPMD 注释参数的 nn.Module。详情请参阅此问题。XLA 集成目前处于实验阶段,可能会发生变化。
DTensor 工厂函数
DTensor 还提供了专门的张量工厂函数,允许直接创建 DTensor。这些函数使用类似 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等),并通过额外指定 DeviceMesh 和 Placement 来配置所创建的 DTensor:
torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)返回一个用标量值0填充的DTensor。
参数
- size ( int *...)- 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:zeros(1,2,3…) 或 zeros([1,2,3…]) 或 zeros((1,2,3…))
关键字参数
- requires_grad ([bool], 可选)- 如果为True,自动微分将记录对返回- DTensor的操作。默认值:- False。
- dtype (torch.dtype- , 可选)- 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)- 返回- DTensor的期望布局。默认值:- torch.strided。
- device_mesh-- DeviceMesh类型,包含rank的网格信息
- placements-- Placement类型的序列:- Shard、- Replicate
返回
每个rank上的一个DTensor对象
返回类型
torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)返回一个填充了标量值1的DTensor,其形状由可变参数size定义。
参数
- size ( int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))
关键字参数
- dtype (torch.dtype- , 可选)– 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回DTensor的期望布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 是否应在返回的- DTensor上记录自动梯度操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含进程的网格信息
- placements–- Placement类型的序列:- Shard、- Replicate
返回值:每个进程上的一个DTensor对象
返回类型:DTensor
torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)返回一个填充了未初始化数据的 DTensor。该 DTensor 的形状由可变参数 size 定义。
参数
- size ( int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:empty(1,2,3…)、empty([1,2,3…]) 或 empty((1,2,3…))。
关键字参数
- dtype (torch.dtype- , 可选)– 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回- DTensor的期望布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 是否在返回的- DTensor上记录自动求导操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含进程的网格信息。
- placements–- Placement类型的序列:- Shard、- Replicate。
返回值:每个进程上的一个 DTensor 对象。
返回类型:DTensor
torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)根据 device_mesh 和 placements 参数,返回一个填充了 fill_value 的 DTensor,其形状由参数 size 定义。
参数
- size ( int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数,也可以是列表或元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))。
- fill_value (Scalar)– 用于填充输出张量的值。
关键字参数
- dtype (torch.dtype- , 可选)– 返回的- DTensor所需的数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回的 DTensor 所需的布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 是否应自动梯度记录对返回的- DTensor的操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含 rank 的网格信息。
- placements–- Placement类型的序列:- Shard、- Replicate。
返回
每个 rank 上的一个 DTensor 对象。
返回类型
torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)返回一个填充了区间 [0, 1) 上均匀分布随机数的 DTensor。张量的形状由可变参数 size 定义。
参数
- size (int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数或类似列表或元组的集合。例如:ones(1,2,3…)、ones([1,2,3…]) 或 ones((1,2,3…))。
关键字参数
- dtype (torch.dtype, 可选)– 返回的- DTensor所需的数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回的 DTensor 所需的布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 如果为- True,则自动微分会记录对返回的- DTensor的操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含进程的网格信息。
- placements–- Placement类型的序列:- Shard、- Replicate。
返回
每个进程上的一个 DTensor 对象。
返回类型
torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)返回一个填充了均值为0、方差为1的正态分布随机数的DTensor,张量的形状由变量参数size定义。
参数
- size (int *...)- 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表/元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))
关键字参数
- dtype (torch.dtype, 可选)- 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选)- 返回DTensor的期望布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)- 是否应在返回的- DTensor上记录自动求导操作。默认值:- False。
- device_mesh-- DeviceMesh类型,包含rank的网格信息。
- placements-- Placement类型的序列:- Shard、- Replicate
返回
每个rank上的一个DTensor对象
返回类型
调试
日志记录
启动程序时,可以通过设置 torch._logging 中的 TORCH_LOGS 环境变量来启用额外的日志记录功能:
- TORCH_LOGS=+dtensor将显示 logging.DEBUG 及以上级别的日志消息
- TORCH_LOGS=dtensor将显示 logging.INFO 及以上级别的日志消息
- TORCH_LOGS=-dtensor将显示 logging.WARNING 及以上级别的日志消息
调试工具
为了调试应用了DTensor的程序,并深入了解底层发生的集合通信细节,DTensor提供了CommDebugMode调试模式:
class torch
.distributed.tensor.debug.CommDebugModeCommDebugMode 是一个上下文管理器,用于统计其上下文中功能集合操作的次数。它通过 TorchDispatchMode 实现这一功能。
注意:目前并非所有集合操作都受支持。
使用示例
mod = ...
comm_mode = CommDebugMode()
with comm_mode:
mod.sum().backward()
print(comm_mode.get_comm_counts())generate_comm_debug_tracing_table(noise_level=3)生成详细表格,展示模块层级的操作和集体追踪信息。信息量取决于 noise_level 参数:
0、打印模块层级的集体调用次数统计
1、打印未包含在简单操作中的 dTensor 操作及模块信息
2、打印未包含在简单操作中的所有操作
3、打印全部操作
generate_json_dump(file_name='comm_mode_log.json', noise_level=3)生成用于构建浏览器可视化的json文件
0、打印模块级别的聚合计数
1、打印未包含在简单操作中的dTensor运算
2、打印未包含在简单操作中的运算
3、打印所有运算
get_comm_counts()返回通信计数作为字典。
返回值:以字典形式返回通信计数。
返回类型:Dict[Any, int]
get_parameter_info()返回类型:dict[str , dict[str , Any ]
get_sharding_info()返回类型 : dict[str, dict[str, Any]]
get_total_counts()返回类型:int
log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)替代控制台 CommDebugMode 输出的方案,可将日志写入用户指定的文件
为了可视化维度少于 3 的 DTensor 分片情况,DTensor 提供了 visualize_sharding() 方法:
torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')在终端中可视化一维或二维 DTensor 的分片情况。
注意:需安装 tabulate 包。空张量不会显示分片信息。
实验性功能
DTensor 还提供了一系列实验性功能。这些功能要么处于原型开发阶段,要么基础功能已完成但正在收集用户反馈。如果您对这些功能有任何意见,请向 PyTorch 提交 issue。
torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)context_parallel 是一个实验性 API,用于实现上下文并行(CP)。该 API 执行两个操作:1) 将 SDPA(torch.nn.functional.scaled_dot_product_attention)替换为支持 CP 的版本;2) 沿序列维度对 buffers 进行分片,每个 rank 根据 mesh 保留对应的分片。
参数
- mesh (DeviceMesh- )– 用于上下文并行的设备网格。
- buffers (Optional[List[torch.Tensor]])– 依赖序列维度的缓冲区。例如输入批次、标签和位置嵌入缓冲区。这些缓冲区必须沿序列维度分片以确保准确性。分片操作会就地执行,缓冲区的形状在上下文中会发生变化。上下文结束后,缓冲区会恢复原状。可以通过- no_restore_buffers指定哪些缓冲区无需恢复。注意- buffers不应包含任何 nn.Parameter。
- buffer_seq_dims (Optional[List[int]])–- buffers的序列维度。
- no_restore_buffers (Optional[Set[torch.Tensor]])– 此集合中的缓冲区在上下文退出后不会被恢复。该集合必须是- buffers的子集。如果缓冲区在上下文退出后不再使用,可以将其加入此列表以避免额外的恢复时间。
返回类型
Generator
警告:torch.distributed._tensor.experimental.attention.context_parallel 是 PyTorch 中的原型功能。API 可能会发生变化。
torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)local_map() 是一个实验性 API,允许用户将 DTensor 传递给原本设计用于处理 torch.Tensor 的函数。其实现原理是提取 DTensor 的本地分量,调用目标函数,然后根据 out_placements 将输出重新封装为 DTensor。
参数说明
- func (Callable)– 需要应用于每个- DTensor本地分片的函数
- out_placements (Union [PlacementType, Tuple[PlacementType, …]])– 函数展平输出中- DTensor的目标分布位置:- 当展平输出为单个值时,out_placements应为 PlacementType 类型
- 当展平输出包含多个值时,out_placements应为与输出值一一对应的 PlacementType 元组
- 对于 Tensor输出,使用 PlacementType 作为其分布位置(即 Tuple[Placement] 值)
- 对于非 Tensor 输出,PlacementType 应为 None
 
- 当展平输出为单个值时,
注意:当没有传入 DTensor 参数时,即使 out_placements 不为 None,结果函数也应忽略目标分布位置,因为此时函数并非运行在 DTensor 上。
- in_placements (Tuple[PlacementType, …], optional)– 函数展平输入中- DTensor的必需分布位置:- 指定时,local_map()会检查每个DTensor参数的分布位置是否符合要求
- 当分布位置不符且 redistribute_inputs=False时会抛出异常
- 当 redistribute_inputs=True时,参数会先重分布到要求的分片位置再传递给函数
- 例外情况:当必需分布位置非 None 但参数是 torch.Tensor时,跳过分布检查直接传递参数
- 默认值:None
 
- 指定时,
- device_mesh (DeviceMesh- , optional)– 所有- DTensor所处的设备网格。未指定时从输入- DTensor的设备网格推断。要求所有- DTensor必须位于同一设备网格。默认值:None
- redistribute_inputs ([bool], optional)– 布尔值,指示当输入- DTensor分布位置与要求不符时是否进行重分布:- 为 False 且需要重分布时会抛出异常
- 默认值:False
 
返回值
 返回一个可调用对象,该对象会将 func 应用于输入 DTensor 的每个本地分片,并将函数返回值构造成 DTensor。
异常情况
- AssertionError– 当出现以下情况时触发:- 输入 DTensor不在同一设备网格
- 输入 DTensor与device_mesh参数指定的设备网格不同
- 非 Tensor 输出对应的 out_placements不为 None
 
- 输入 
- ValueError– 当- redistribute_inputs=False但输入- DTensor需要根据- in_placements重分布时触发
示例:
>>
>
def mm_allreduce_forward(device_mesh, W, X):
>>
> partial_sum_tensor = torch.mm(W, X)
>>
> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>
>
return reduced_tensor
>>
>
>>
> W = torch.randn(12, 8, requires_grad=False)
>>
> X = torch.randn(8, 16, requires_grad=False)
>>
> Y = torch.mm(W, X)
>>
> row_wise = [Shard(0)] # 在一维网格上的行分片布局
>>
> col_wise = [Shard(1)] # 在一维网格上的列分片布局
>>
>
>>
>
# local_mm_allreduce_forward是封装了DTensor/Tensor转换的函数
>>
> local_mm_allreduce_forward = local_map(
>>
> mm_allreduce_forward,
>>
> out_placements=[Replicate()],
>>
> in_placements=[col_wise, row_wise],
>>
> device_mesh=device_mesh,
>>
>
)
>>
>
>>
> W_dt = distribute_tensor(
... W, device_mesh, (col_wise)
... ) # 列分片的W张量
>>
> X_dt = distribute_tensor(
... X, device_mesh, (row_wise)
... ) # 行分片的X张量
>>
> Y_dt = local_mm_allreduce_forward(
... device_mesh, W_dt, X_dt
... ) # 对DTensors应用local_mm_allreduce_forward注意:此 API 目前处于实验阶段,可能会发生变化
torch.distributed.tensor.experimental.register_sharding(op)register_sharding() 是一个实验性 API,允许用户在张量输入输出为 DTensor 时,为运算符注册分片策略。
该 API 在以下场景中特别有用:(1) 当 op 不存在默认分片策略时(例如 op 是 DTensor 不支持的定制运算符);(2) 当用户希望覆盖现有运算符的默认分片策略时。
参数说明
- op (Union[OpOverload*, List[OpOverload]])—— 需要注册自定义分片函数的单个运算符或运算符列表。
返回值
返回一个函数装饰器,可用于包装定义 op 所指定运算符分片策略的函数。定义的分片策略将被注册到 DTensor 中,若 DTensor 已实现该运算符,则会覆盖其默认分片策略。自定义分片函数的输入参数与原运算符相同(若参数为 torch.Tensor 则会被替换为 DTensor 内部使用的类张量对象)。该函数应返回由二元组构成的序列,每个二元组分别指定可接受的输出布局及其对应的输入布局。
使用示例
>>
> @register_sharding(aten._softmax.default)
>>
>
def custom_softmax_sharding(x, dim, half_to_float):
>>
> softmax_dim = dim if dim >= 0 else dim + x.ndim
>>
> acceptable_shardings = []
>>
>
>>
> all_replicate = ([Replicate()], [Replicate(), None, None])
>>
> acceptable_shardings.append(all_replicate)
>>
>
>>
>
for sharding_dim in range(x.ndim):
>>
>
if sharding_dim != softmax_dim:
>>
> all_sharded = (
>>
>
[Shard(sharding_dim)],
>>
>
[Shard(sharding_dim), None, None],
>>
>
)
>>
> acceptable_shardings.append(all_sharded)
>>
>
>>
>
return acceptable_shardings注意:此 API 目前处于实验阶段,后续可能会发生变化
torch.distributed.tensor
注意:torch.distributed.tensor 目前处于 alpha 开发阶段,文档中列出的大部分 API 我们将确保向后兼容性,但必要时可能会进行 API 变更。
PyTorch DTensor(分布式张量)
PyTorch DTensor 提供简单灵活的张量分片原语,能够透明处理分布式逻辑,包括跨设备/主机的分片存储、算子计算和集合通信。DTensor 可用于构建不同的并行解决方案,并支持在多维分片场景下表示分片状态的 state_dict。
以下是基于 DTensor 构建的 PyTorch 原生并行方案示例:
DTensor 遵循 SPMD(单程序多数据)编程模型,让用户能够像编写具有相同收敛特性的单设备程序那样编写分布式程序。它通过指定 DeviceMesh 和 Placement 提供统一的张量分片布局(DTensor 布局):
- DeviceMesh使用 n 维数组表示集群的设备拓扑和通信器
- Placement描述逻辑张量在- DeviceMesh上的分片布局
 DTensor 支持三种分片类型:- Shard(分片)、- Replicate(复制)和- Partial(部分)。
DTensor 类 API
DTensor 是 torch.Tensor 的子类。这意味着一旦创建了 DTensor,就可以以与 torch.Tensor 非常相似的方式使用它,包括运行不同类型的 PyTorch 操作符,就像在单个设备上运行它们一样,同时为 PyTorch 操作符提供正确的分布式计算支持。
除了现有的 torch.Tensor 方法外,它还提供了一组额外的方法来与 torch.Tensor 交互、将 DTensor 布局重新分配到新的 DTensor、获取所有设备上的完整张量内容等。
class torch
.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)DTensor (Distributed Tensor) is a subclass of torch.Tensor that provides single-device like
 abstraction to program with multi-device torch.Tensor. It describes the distributed tensor sharding
 layout (DTensor Layout) through the DeviceMesh and following types of Placement:
- Shard: Tensor sharded on the tensor dimension- dimon the devices of the- DeviceMeshdimension
- Replicate: Tensor replicated on the devices of the- DeviceMeshdimension
- Partial: Tensor is pending reduction on the devices of the- DeviceMeshdimension
When calling PyTorch operators, DTensor overrides the PyTorch operators to perform sharded computation and issue
 communications whenever necessary. Along with the operator computation, DTensor will transform or propagate the placements (DTensor Layout) properly (based on the operator semantic itself) and generate new DTensor outputs.
To ensure numerical correctness of the DTensor sharded computation when calling PyTorch operators, DTensor
 requires every Tensor argument of the operator be DTensor.
Note: Directly using the Tensor subclass constructor here is not the recommended way to create a DTensor
 (i.e. it does not handle autograd correctly hence is not the public API). Please refer to the create_dtensor
 section to see how to create a DTensor.
Return type
DTensor
__create_chunk_list__()返回一个 ChunkStorageMetadata 列表,该数据类用于描述当前 rank 上本地分片/副本的大小和偏移量。对于 DTensor,每个 rank 只会有一个本地分片/副本,因此返回的列表通常仅包含一个元素。
此双下划线方法主要用于分布式检查点用途。
返回值:一个 List[ChunkStorageMetadata] 对象,表示当前 rank 上的分片大小/偏移量。
static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)根据指定的 device_mesh 和 placements,从各 rank 上的本地 torch.Tensor 创建一个 DTensor
参数
- local_tensor (torch.Tensor)– 各 rank 上的本地 torch.Tensor。
- device_mesh (DeviceMesh- , 可选)– 用于放置张量的 DeviceMesh。若未指定,则必须在 DeviceMesh 上下文管理器中调用,默认值:None
- placements (List[Placement- ], 可选)– 描述如何将本地 torch.Tensor 放置在 DeviceMesh 上的布局列表,其元素数量必须与- device_mesh.ndim相同。
关键字参数
- run_check ([bool], 可选)– 以额外通信为代价,跨 rank 执行完整性检查,验证各本地张量的元信息以确保正确性。若- placements中包含- Replicate,设备网格维度的第一个 rank 上的数据将被广播到其他 rank。默认值:False
- shape ( torch.Size , 可选)– 指定构建在 local_tensor 之上的 DTensor 大小的整型列表。注意:当各 rank 上- local_tensor的形状不同时必须提供此参数。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算- shape。默认值:None
- stride ( tuple , 可选)– 指定 DTensor 步长的整型列表。若未提供,将假设给定的分布式张量均匀分片到各 rank 来计算- stride。默认值:None
返回
一个 DTensor 对象
返回类型:DTensor
注意:当 run_check=False 时,用户需自行确保传入的本地张量在各 rank 间正确(即对于 Shard(dim) 布局张量需分片,对于 Replicate() 布局需复制)。否则,所创建 DTensor 的行为将是未定义的。
注意:from_local 是可微操作,创建的 DTensor 对象的 requires_grad 属性将取决于 local_tensor 是否 requires_grad。
full_tensor(*, grad_placements=None)Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate
 them together. It’s a syntatic sugar of the following code:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
Keyword Arguments
- grad_placements (List[Placement- ], optional)– the placements describes the future layout of any gradient layout of the full Tensor returned from this function.
 full_tensor converts DTensor to a full torch.Tensor and the returned torch.tensor
 might not be used as the original replicated DTensor layout later in the code. This
 argument is the hint that user can give to autograd in case the gradient
 layout of the returned tensor does not match the original replicated DTensor layout.
 If not specified, we will assume the gradient layout of the full tensor be replicated.
Returns
 A torch.Tensor object that represents the full tensor of this DTensor.
Return type : Tensor 
Note: full_tensor is differentiable.
redistribute(device_mesh=None, placements=None, *, async_op=False)redistribute 执行必要的集体操作,将当前 DTensor 从其现有布局重新分配到新布局,或从当前 DeviceMesh 迁移到新 DeviceMesh。例如,我们可以通过为 DeviceMesh 的每个维度指定 Replicate 布局,将分片(Sharded)DTensor 转换为复制(Replicated)DTensor。
当在 DeviceMesh 的某个维度上从当前布局重新分配到新布局时,将执行以下包含通信集体操作或本地操作:
1、Shard(dim) → Replicate():all_gather
 2、Shard(src_dim) → Shard(dst_dim):all_to_all
 3、Replicate() → Shard(dim):本地分块(即 torch.chunk)
 4、Partial() → Replicate():all_reduce
 5、Partial() → Shard(dim):reduce_scatter
redistribute 能够正确推断出针对在 1-D 或 N-D DeviceMesh 上创建的 DTensor 所需的重新分配步骤。
参数
- device_mesh (DeviceMesh- , 可选)– 用于放置 DTensor 的 DeviceMesh。若未指定,则使用当前 DTensor 的 DeviceMesh。
 默认值:None
- placements (List[Placement- ], 可选)– 描述如何将 DTensor 放置到 DeviceMesh 中的新布局,其元素数量必须与- device_mesh.ndim相同。
 默认值:在所有网格维度上复制(replicate)
关键字参数
- async_op ([bool], 可选)– 是否以异步方式执行 DTensor 重新分配操作。默认值:False
返回
 一个 DTensor 对象
返回类型
DTensor
注意:redistribute 是可微分的,这意味着用户无需担心重新分配操作的反向传播公式。
注意:redistribute 当前仅支持在同一 DeviceMesh 上重新分配 DTensor。若需将 DTensor 重新分配到不同 DeviceMesh,请提交问题。
to_local(*, grad_placements=None)Get the local tensor of this DTensor on its current rank. For sharding it returns a local shard of the logical tensor view, for replication it returns the replica on its current rank.
Keyword Arguments
- grad_placements (List[Placement- ], optional)– the placements describes the future layout of any gradient layout of the Tensor returned from this function.
 to_local converts DTensor to local tensor and the returned local tensor
 might not be used as the original DTensor layout later in the code. This
 argument is the hint that user can give to autograd in case the gradient
 layout of the returned tensor does not match the original DTensor layout.
 If not specified, we will assume the gradient layout remains the same as the original DTensor and use that for gradient computation.
Returns
 A torch.Tensor or AsyncCollectiveTensor object. it represents the local tensor on its current rank. When an AsyncCollectiveTensor object is returned, it means the local tensor is not ready yet (i.e. communication is not finished). In this case, user needs to call wait to wait the local tensor to be ready.
Return type : Tensor 
Note: to_local is differentiable, the requires_grad of the local tensor returned
 will depend on if the DTensor requires_grad or not.
property device_mesh: [DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")The DeviceMesh attribute that associates with this DTensor object.
Note: device_mesh is a read-only property, it can not be set.
property placements: tuple [[torch.distributed.tensor.placement_types.Placement](https://pytorch.org/docs/stable/data.html#torch.distributed.tensor.placement_types.Placement "torch.distributed.tensor.placement_types.Placement"),
...]该 DTensor 的 placements 属性描述了其在设备网格(DeviceMesh)上的分布布局。
注意:placements 是只读属性,不可被修改。
作为分布式通信器的DeviceMesh
DeviceMesh基于DTensor构建,用于抽象描述集群设备拓扑结构,并作为多维通信器(基于ProcessGroup)的载体。如需了解如何创建/使用DeviceMesh的具体细节,请参阅DeviceMesh使用指南。
DTensor 布局类型
DTensor 支持在每个 DeviceMesh 维度上使用以下 Placement 类型:
class torch
.distributed.tensor.placement_types.Shard(dim)Shard(dim)布局描述了张量在维度dim上跨对应DeviceMesh维度的分片方式,其中DeviceMesh维度上的每个rank仅持有全局张量的一个分片。Shard(dim)布局遵循torch.chunk(dim)语义——当张量维度无法在DeviceMesh维度上均匀划分时,DeviceMesh维度上的最后几个分片可能为空。所有DTensor API(如distribute_tensor、from_local等)均可使用Shard布局。
参数
- dim (int)- 指定张量在对应DeviceMesh维度上进行分片的维度编号。
警告:当前对无法在DeviceMesh维度上均匀划分的张量维度进行分片属于实验性功能,后续可能变更。
dim: int
class torch
.distributed.tensor.placement_types.ReplicateThe Replicate() placement describes the DTensor replicating on a corresponding DeviceMesh dimension, where each rank on the DeviceMesh dimension holds a replica of the global Tensor. The Replicate placement can be used by all DTensor APIs (i.e. distribute_tensor, DTensor.from_local, etc.)
class torch
.distributed.tensor.placement_types.Partial(reduce_op='sum')The Partial(reduce_op) placement describes the DTensor that is pending reduction on a specified DeviceMesh dimension, where each rank on the DeviceMesh dimension holds the partial value of the global Tensor. User can redistribute the Partial DTensor to a Replicate or Shard(dim)
 placement on the specified DeviceMesh dimension using redistribute, which would trigger necessary communication operations under the hood (i.e. allreduce, reduce_scatter).
Parameters
- reduce_op (str, optional)– The reduction op to be used for the partial DTensor to produce Replicated/Sharded DTensor. Only element-wise reduction operations are supported, including: “sum”, “avg”, “product”, “max”, “min”, default: “sum”.
Note: The Partial placement can be generated as a result of the DTensor operators, and can only be used by the DTensor.from_local API.
reduce_op: str = 'sum'class torch
.distributed.tensor.placement_types.PlacementPlacement 类型的基类,用于描述如何将 DTensor 放置在 DeviceMesh 上。Placement 和 DeviceMesh 共同定义了 DTensor 的布局。
它是三种主要 DTensor 放置类型(Shard、Replicate 和 Partial)的基类。
这个类不直接使用,主要作为类型标注存根。
is_partial()Return type : bool
is_replicate()返回类型:bool
is_shard(dim=None)Return type : bool
Different ways to create a DTensor
There’re three ways to construct a DTensor distribute_tensor() creates a DTensor from a logical or “global” torch.Tensor on each rank. This could be used to shard the leaf torch.Tensor s (i.e. model parameters/buffers and inputs).
- DTensor.from_local()creates a- DTensorfrom a local- torch.Tensoron each rank, which can be used to create- DTensorfrom a non-leaf- torch.Tensors (i.e. intermediate activation
 tensors during forward/backward).
- DTensor provides dedicated tensor factory functions (e.g. empty(),ones(),randn(), etc.) to allow differentDTensorcreations by directly specifying theDeviceMeshandPlacement. Compare todistribute_tensor(), this could directly materializing the sharded memory on device, instead of performing sharding after initializing the logical Tensor memory.
Create DTensor from a logical torch.Tensor
The SPMD (single program, multiple data) programming model in torch.distributed launches multiple processes
 (i.e. via torchrun) to execute the same program, this means that the model inside the program would be initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly on GPU if enough memory).
DTensor offers a distribute_tensor() API that could shard the model weights or Tensors to DTensor s, where it would create a DTensor from the “logical” Tensor on each process. This would empower the createdDTensor s to comply with the single device semantic, which is critical for numerical correctness.
torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)Distribute a leaf torch.Tensor (i.e. nn.Parameter/buffers) to the device_mesh according to the placements specified. The rank of device_mesh and placements must be the same. The tensor to distribute is the logical or “global” tensor, and the API would use the tensor from first rank of the DeviceMesh dimension as the source of truth to preserve the single-device semantic. If you want to construct a DTensor in the middle of the Autograd
 computation, please use DTensor.from_local() instead.
Parameters
- tensor (torch.Tensor)– torch.Tensor to be distributed. Note that if you want to shard a tensor on a dimension that is not evenly divisible by the number of devices in that mesh dimension, we use- torch.chunk
 semantic to shard the tensor and scatter the shards. The uneven sharding
 behavior is experimental and subject to change.
- device_mesh (DeviceMesh- , optional)– DeviceMesh to distribute the tensor, if not specified, must be called under a DeviceMesh context
 manager, default: None
- placements (List[Placement- ], optional)– the placements that describes how to place the tensor on DeviceMesh, must have the same number of elements as- device_mesh.ndim. If not specified, we will by default replicate the tensor across the- device_meshfrom the first rank of each dimension of the device_mesh.
Keyword Arguments
- src_data_rank ( int , optional)– the rank of the source data for the logical/global tensor, it is used by- distribute_tensor()to scatter/broadcast the shards/replicas to other ranks. by default, we use- group_rank=0on each DeviceMesh dimension as the source data to preserve the single-device semantic. If passing- Noneexplicitly,- distribute_tensor()simply uses
 its local data instead of trying to preserve the single-device semantic via scatter/broadcast.
 Default: 0
Returns
 A DTensor or XLAShardedTensor object.
Return type
DTensor
Note: When initialize the DeviceMesh with the xla device_type, distribute_tensor
 return XLAShardedTensor instead. see this issuefor more details. The XLA integration is experimental and subject to change.
Along with distribute_tensor(), DTensor also offers a distribute_module() API to allow easier
 sharding on the nn.Module level
torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)该函数提供了三个功能来控制模块的参数/输入/输出:
1、通过在运行时执行前指定 partition_fn 对模块进行分片处理(即允许用户根据指定的 partition_fn 将模块参数转换为 DTensor 参数)。
2、通过在运行时执行时指定 input_fn 和 output_fn 来控制模块的输入或输出(即将输入转换为 DTensor,将输出转换回 torch.Tensor)。
参数
- module (nn.Module- )– 需要分片的用户模块。
- device_mesh (DeviceMesh- )– 用于放置模块的设备网格。
- partition_fn (Callable)– 用于分片参数的函数(即在- device_mesh上切分特定参数)。如果未指定- partition_fn,默认会在网格上复制- module的所有模块参数。
- input_fn (Callable)– 指定输入分布,即可以控制模块输入的切分方式。- input_fn会作为模块的- forward_pre_hook(前向钩子)安装。
- output_fn (Callable)– 指定输出分布,即可以控制输出的切分方式,或将其转换回 torch.Tensor。- output_fn会作为模块的- forward_hook(后向钩子)安装。
返回
一个包含所有参数/缓冲区的模块,这些参数/缓冲区均为 DTensor 类型。
返回类型:Module
注意:当使用 xla 设备类型初始化 DeviceMesh 时,distribute_module 会返回带有 PyTorch/XLA SPMD 注释参数的 nn.Module。详情请参阅此问题。XLA 集成目前处于实验阶段,可能会发生变化。
DTensor 工厂函数
DTensor 还提供了专门的张量工厂函数,允许直接创建 DTensor。这些函数使用类似 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等),并通过额外指定 DeviceMesh 和 Placement 来配置所创建的 DTensor:
torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)Returns a DTensor filled with the scalar value 0.
Parameters
- size ( int *...)– a sequence of integers defining the shape of the output- DTensor.
 Can be a variable number of arguments or a collection like a list or tuple.
 E.g.: zeros(1,2,3…) or zeros([1,2,3…]) or zeros((1,2,3…))
Keyword Arguments
- requires_grad ([bool], optional)– If autograd should record operations on the returned- DTensor. Default:- False.
- dtype (torch.dtype- , optional)– the desired data type of returned- DTensor.
 Default: if- None, uses a global default (see- torch.set_default_dtype()).
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), optional)– the desired layout of returned- DTensor.
 Default:- torch.strided.
- device_mesh–- DeviceMeshtype, contains the mesh info of ranks
- placements– a sequence of- Placementtype:- Shard,- Replicate
Returns
 A DTensor object on each rank
Return type
DTensor
torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)返回一个填充了标量值1的DTensor,其形状由可变参数size定义。
参数
- size ( int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表、元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))
关键字参数
- dtype (torch.dtype- , 可选)– 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回DTensor的期望布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 是否应在返回的- DTensor上记录自动梯度操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含进程的网格信息
- placements–- Placement类型的序列:- Shard、- Replicate
返回值:每个进程上的一个DTensor对象
返回类型:DTensor
torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)Returns a DTensor filled with uninitialized data. The shape of the DTensor is defined by the variable argument size.
Parameters
- size ( int *...)– a sequence of integers defining the shape of the output- DTensor.
 Can be a variable number of arguments or a collection like a list or tuple.
 E.g.: empty(1,2,3…) or empty([1,2,3…]) or empty((1,2,3…))
Keyword Arguments
- dtype (torch.dtype- , optional)– the desired data type of returned- DTensor.
 Default: if- None, uses a global default (see- torch.set_default_dtype()). layout (- torch.layout, optional): the desired layout of returned- DTensor.
 Default:- torch.strided.
- requires_grad ([bool], optional)– If autograd should record operations on the returned- DTensor. Default:- False.
- device_mesh–- DeviceMeshtype, contains the mesh info of ranks
- placements– a sequence of- Placementtype:- Shard,- Replicate
Returns
 A DTensor object on each rank
Return type
DTensor
torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)根据 device_mesh 和 placements 参数,返回一个填充了 fill_value 的 DTensor,其形状由参数 size 定义。
参数
- size ( int *...)– 定义输出- DTensor形状的整数序列。可以是可变数量的参数,也可以是列表或元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))。
- fill_value (Scalar)– 用于填充输出张量的值。
关键字参数
- dtype (torch.dtype- , 可选)– 返回的- DTensor所需的数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), 可选)– 返回的 DTensor 所需的布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)– 是否应自动梯度记录对返回的- DTensor的操作。默认值:- False。
- device_mesh–- DeviceMesh类型,包含 rank 的网格信息。
- placements–- Placement类型的序列:- Shard、- Replicate。
返回
每个 rank 上的一个 DTensor 对象。
返回类型
torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)Returns a DTensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable
 argument size.
Parameters
- size ( int *...)– a sequence of integers defining the shape of the output- DTensor.
 Can be a variable number of arguments or a collection like a list or tuple.
 E.g.: ones(1,2,3…) or ones([1,2,3…]) or ones((1,2,3…))
Keyword Arguments
- dtype (torch.dtype- , optional)– the desired data type of returned- DTensor.
 Default: if- None, uses a global default (see- torch.set_default_dtype()).
- layout ([torch.layout- ](tensor_attributes.html#torch.layout "torch.layout"), optional)– the desired layout of returned DTensor.
 Default:- torch.strided.
- requires_grad ([bool], optional)– If autograd should record operations on the returned- DTensor. Default:- False.
- device_mesh–- DeviceMeshtype, contains the mesh info of ranks.
- placements– a sequence of- Placementtype:- Shard,- Replicate
Returns
 A DTensor object on each rank
Return type
DTensor
torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)返回一个填充了均值为0、方差为1的正态分布随机数的DTensor,张量的形状由变量参数size定义。
参数
- size (int *...)- 定义输出- DTensor形状的整数序列。可以是可变数量的参数或列表/元组等集合。例如:ones(1,2,3…) 或 ones([1,2,3…]) 或 ones((1,2,3…))
关键字参数
- dtype (torch.dtype, 可选)- 返回- DTensor的期望数据类型。默认值:如果为- None,则使用全局默认值(参见- torch.set_default_dtype())。
- layout ([torch.layout](tensor_attributes.html#torch.layout "torch.layout"), 可选)- 返回DTensor的期望布局。默认值:- torch.strided。
- requires_grad ([bool], 可选)- 是否应在返回的- DTensor上记录自动求导操作。默认值:- False。
- device_mesh-- DeviceMesh类型,包含rank的网格信息。
- placements-- Placement类型的序列:- Shard、- Replicate
返回
每个rank上的一个DTensor对象
返回类型
调试
日志记录
启动程序时,可以通过设置 torch._logging 中的 TORCH_LOGS 环境变量来启用额外的日志记录功能:
- TORCH_LOGS=+dtensor将显示 logging.DEBUG 及以上级别的日志消息
- TORCH_LOGS=dtensor将显示 logging.INFO 及以上级别的日志消息
- TORCH_LOGS=-dtensor将显示 logging.WARNING 及以上级别的日志消息
调试工具
为了调试应用了DTensor的程序,并深入了解底层发生的集合通信细节,DTensor提供了CommDebugMode调试模式:
class torch
.distributed.tensor.debug.CommDebugModeCommDebugMode is a context manager that counts the number of functional collectives within its context. It does this using a TorchDispatchMode.
Note: Not all collectives are supported yet.
Example usage
mod = ...
comm_mode = CommDebugMode()
with comm_mode:
mod.sum().backward()
print(comm_mode.get_comm_counts())generate_comm_debug_tracing_table(noise_level=3)生成详细表格,展示模块层级的操作和集体追踪信息。信息量取决于 noise_level 参数:
0、打印模块层级的集体调用次数统计
1、打印未包含在简单操作中的 dTensor 操作及模块信息
2、打印未包含在简单操作中的所有操作
3、打印全部操作
generate_json_dump(file_name='comm_mode_log.json', noise_level=3)Creates json file used to build browser visual
 0、prints module-level collective counts
 1、prints dTensor operations not included in trivial operations
 2、prints operations not included in trivial operations
 3、prints all operations
get_comm_counts()返回通信计数作为字典。
返回值:以字典形式返回通信计数。
返回类型:Dict[Any, int]
get_parameter_info()Return type : dict[str , dict[str , Any ]
get_sharding_info()返回类型 : dict[str, dict[str, Any]]
get_total_counts()Return type : int
log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)替代控制台 CommDebugMode 输出的方案,可将日志写入用户指定的文件
为了可视化维度少于 3 的 DTensor 分片情况,DTensor 提供了 visualize_sharding() 方法:
torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')Visualizes sharding in the terminal for DTensor that are 1D or 2D.
Note: This requires the tabulate package. No sharding info will be printed for empty tensors
Experimental Features
DTensor also provides a set of experimental features. These features are either in prototyping stage, or the basic
 functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to these features.
torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)context_parallel is an experimental API to enable context
 parallelism (CP). This API performs two actions: 1) patch the SDPA
 (torch.nn.functional.scaled_dot_product_attention) with the CP-enabled
 one, 2) shard buffers along the sequence dimension and each rank will
 preserve the corresponding shard according mesh.
Parameters
- mesh (DeviceMesh- )– the device mesh for the context parallelism.
- buffers (Optional[List[torch.Tensor]])– buffers that the usage depend on the sequence dimension. Examples are input batch, labels and positional embedding buffers. These buffers must be sharded along the sequence dimension to ensure the accuracy. The sharding will
 happen in-place, the buffer’s shape will change within the context.
 The buffers will be restored after the context finishes.- no_restore_bufferscan be used to specify which buffers don’t
 need to be restored. Note that- buffersshould not contain any
 nn.Parameter.
- buffer_seq_dims (Optional[List[int]])– the sequence dimensions of- buffers.
- no_restore_buffers (Optional[Set[torch.Tensor]])– buffers in these set
 won’t be restored after the context exits. This set must be a subset of- buffers. If the buffers won’t be used after the context exits, these buffers can be put in this list to avoid extra restore time.
Return type
 Generator
Warning: torch.distributed._tensor.experimental.attention.context_parallel is a prototype feature in PyTorch. The API is subject to change.
torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)local_map() is an experimental API that allows users to pass DTensor s to a function that is written to be applied on torch.Tensor s. It is done by extracting the local components of DTensor, call the function, and wrap the outputs to DTensor according to the out_placements.
Parameters
- func (Callable)– the function to be applied on each local shard of- DTensors.
- out_placements (Union [PlacementType, Tuple[PlacementType, …]])– the desired placements of the- DTensors in- func’s flattened output.
 If the flattened- outputis a single value, the- out_placementsshould be of type PlacementType. Otherwise if the flattened- outputhas multiple
 values, the- out_placementsshould be a tuple of PlacementType values 1:1
 mapping to the flattened- output.
 Besides, for- Tensoroutput, we use PlacementType as its
 placements (a Tuple[Placement] value). For non-Tensor output, the PlacementType
 should be None.
 Note that the only exception is when no- DTensorargument is passed
 in. In this case, even if out_placements is not None, the result function
 should ignore the desired placements because the function is not running with- DTensors.
- in_placements (Tuple[PlacementType, …], optional)– the required placements of the- DTensors in the flattened inputs of- func.
 If- in_placementsis specified,- local_map()would examine whether the placements of each- DTensorargument is the same as the required
 placements or not. If the placements are not the same and- redistribute_inputsis- False, an exception will be raised. Otherwise if- redistribute_inputsis- True, the argument will be first redistributed to the required sharding placements before passing its local tensor to- func.
 The only exception is when required placements are not- Noneand the argument is a- torch.Tensor. In this case, the placements examination
 will be skipped and the argument will be directly passed to- func.
 If- in_placementsis- None, no placements examination will be performed.
 Default: None
- device_mesh (DeviceMesh- , optional)– the device mesh that all the- DTensors are placed on. If not
 specified, this will be inferred from the input- DTensors’ device
 mesh. local_map requires every- DTensors to be placed on the same device mesh. Default: None.
- redistribute_inputs ([bool], optional)– the bool value indicating whether to reshard the input- DTensors when
 their placements are different from the required input placements. If this value is- Falseand some- DTensorinput has a different placement, an exception will be raised. Default: False.
Returns
 A Callable that applies func to each local shard of the input DTensor and returns a DTensor constructed from the return value of func.
Raises
- AssertionError– If the input- DTensoris not placed on the same device
 mesh, or if they are placed on a different device mesh than the- device_mesh
 argument passed in.
- AssertionError– For any non-DTensor output, we require its corresponding
 output placement in- out_placementsbe None. An AssertionError will be raised
 if this is not the case.
- ValueError– If- redistribute_inputs=Falsebut the input- DTensorneeds
 a redistribution according to- in_placements.
Example :
>>
>
def mm_allreduce_forward(device_mesh, W, X):
>>
> partial_sum_tensor = torch.mm(W, X)
>>
> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>
>
return reduced_tensor
>>
>
>>
> W = torch.randn(12, 8, requires_grad=False)
>>
> X = torch.randn(8, 16, requires_grad=False)
>>
> Y = torch.mm(W, X)
>>
> row_wise = [Shard(0)] # 在一维网格上的行分片布局
>>
> col_wise = [Shard(1)] # 在一维网格上的列分片布局
>>
>
>>
>
# local_mm_allreduce_forward是封装了DTensor/Tensor转换的函数
>>
> local_mm_allreduce_forward = local_map(
>>
> mm_allreduce_forward,
>>
> out_placements=[Replicate()],
>>
> in_placements=[col_wise, row_wise],
>>
> device_mesh=device_mesh,
>>
>
)
>>
>
>>
> W_dt = distribute_tensor(
... W, device_mesh, (col_wise)
... ) # 列分片的W张量
>>
> X_dt = distribute_tensor(
... X, device_mesh, (row_wise)
... ) # 行分片的X张量
>>
> Y_dt = local_mm_allreduce_forward(
... device_mesh, W_dt, X_dt
... ) # 对DTensors应用local_mm_allreduce_forwardNote: This API is currently experimental and subject to change
torch.distributed.tensor.experimental.register_sharding(op)register_sharding() is an experimental API that allows users to register sharding
 strategies for an operator when the tensor inputs and outputs are DTensor.
 It can be useful when: (1) there doesn’t exist a default sharding strategy for op, e.g. when op is a custom operator that is not supported by DTensor; (2)
 when users would like to overwrite default sharding strategies of existing operators.
Parameters
- op (Union[OpOverload*, List[OpOverload]])– An op or a list of ops to register the customized sharding function.
Returns
 A function decorator which can be used to wrap a function that defines the sharding
 strategy for the operator specified in op. The defined sharding strategy will be registered to DTensor and will override the default sharding strategy if DTensor has already implemented the operator. The customized sharding function takes the same inputs as the original op (except that if an arg is a torch.Tensor , it will be replaced by a tensor-like object that DTensor uses internally). The function should
 return a sequence of 2-tuples, each specifying acceptable output placements and its
 corresponding intput placements.
Example:
>>
> @register_sharding(aten._softmax.default)
>>
>
def custom_softmax_sharding(x, dim, half_to_float):
>>
> softmax_dim = dim if dim >= 0 else dim + x.ndim
>>
> acceptable_shardings = []
>>
>
>>
> all_replicate = ([Replicate()], [Replicate(), None, None])
>>
> acceptable_shardings.append(all_replicate)
>>
>
>>
>
for sharding_dim in range(x.ndim):
>>
>
if sharding_dim != softmax_dim:
>>
> all_sharded = (
>>
>
[Shard(sharding_dim)],
>>
>
[Shard(sharding_dim), None, None],
>>
>
)
>>
> acceptable_shardings.append(all_sharded)
>>
>
>>
>
return acceptable_shardings注意:此 API 目前处于实验阶段,后续可能会发生变化
通用Join上下文管理器
通用Join上下文管理器用于简化不均匀输入的分布式训练。本文档概述了相关类的API:Join、Joinable和JoinHook。如需教程,请参阅使用Join上下文管理器进行不均匀输入的分布式训练。
class torch
.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)该类定义了通用的join上下文管理器,允许在进程加入后调用自定义钩子。
这些钩子应屏蔽未加入进程的集体通信,以防止挂起和错误,并确保算法正确性。有关钩子定义的详细信息,请参阅JoinHook。
警告:上下文管理器要求每个参与的Joinable在其每次迭代的集体通信之前调用notify_join_context()方法以确保正确性。
警告:上下文管理器要求JoinHook对象中的所有process_group属性必须相同。如果存在多个JoinHook对象,则使用第一个对象的device。
进程组和设备信息用于检查未加入的进程,并在启用throw_on_early_termination时通知进程抛出异常,这两者都使用all-reduce操作。
参数
- joinables (List[Joinable ])– 参与的- Joinable列表;将按给定顺序迭代它们的钩子。
- enable ([bool])– 启用不均匀输入检测的标志;设置为- False将禁用上下文管理器的功能,仅当用户确认输入不会不均匀时才应设置(默认值:- True)。
- throw_on_early_termination ([bool])– 控制检测到不均匀输入时是否抛出异常的标志(默认值:- False)。
示例:
>>
>
import os
>>
>
import torch
>>
>
import torch.distributed as dist
>>
>
import torch.multiprocessing as mp
>>
>
import torch.nn.parallel.DistributedDataParallel as DDP
>>
>
import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>
>
from torch.distributed.algorithms.join import Join
>>
>
>
>>
>
# On each spawned worker
>>
>
def worker(rank):
>>
> dist.init_process_group("nccl", rank=rank, world_size=2)
>>
> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>
> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>
>
# Rank 1 gets one more input than rank 0
>>
> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>
>
with Join([model, optim]):
>>
>
for input in inputs:
>>
> loss = model(input).sum()
>>
> loss.backward()
>>
> optim.step()
>>
>
# All ranks reach here without hanging/erroringstatic notify_join_context(joinable)通知连接上下文管理器,调用进程尚未加入。
如果设置了 throw_on_early_termination=True,则会检查是否检测到输入不均衡
 (即是否有进程已提前加入),若存在则抛出异常。
此方法应在 Joinable 对象执行每次迭代的集合通信前调用。
 例如,在 DistributedDataParallel 的前向传播开始时应当调用此方法。
只有传入上下文管理器的第一个 Joinable 对象会执行此方法中的集合通信,
 其余对象调用此方法时不执行实际操作。
参数
- joinable (Joinable)– 调用此方法的- Joinable对象。
返回值
 若当前 joinable 是传入上下文管理器的首个对象,则返回一个异步工作句柄,
 用于通过全减操作通知上下文管理器该进程尚未加入;否则返回 None。
class torch
.distributed.algorithms.Joinable这里定义了一个可加入类的抽象基类。
一个可加入类(继承自 Joinable)需要实现以下内容:
- join_hook()方法,返回一个- JoinHook实例
- join_device()方法,返回设备信息
- join_process_group()方法,返回进程组信息
ABSTRACT PROPERTY join_device: device返回用于执行由 join 上下文管理器所需的集体通信的设备。
ABSTRACT join_hook(**kwargs)为给定的 Joinable 返回一个 JoinHook 实例。
参数
- kwargs (dict)- 包含运行时修改 join hook 行为的关键字参数字典;所有共享相同 join 上下文管理器的- Joinable实例都会收到相同的- kwargs值。
返回类型:JoinHook
ABSTRACT PROPERTY join_process_group: Any返回连接上下文管理器本身所需的集体通信的进程组。
class torch
.distributed.algorithms.JoinHook这里定义了一个连接钩子(join hook),它在连接上下文管理器中提供了两个入口点:
入口点包括:
 1、主钩子(main hook):当存在未连接的进程时会被重复调用
 2、后置钩子(post-hook):当所有进程都完成连接后会被调用一次
要为通用连接上下文管理器实现连接钩子,需要定义一个继承自JoinHook的类,并根据需要重写main_hook()和post_hook()方法。
main_hook()在训练迭代中存在未加入的进程时调用此钩子,以跟踪集体通信。
训练迭代指的是:一次前向传播、反向传播和优化器步骤的过程。
post_hook(is_last_joiner)在所有进程都加入后调用钩子。
该钩子会接收一个额外的 bool 类型参数 is_last_joiner,用于指示当前 rank 是否属于最后一批加入的进程。
参数
- is_last_joiner ([bool])– 当 rank 属于最后一批加入的进程时为- True;否则为- False。
Torch Distributed Elastic
为分布式 PyTorch 提供容错与弹性能力。
快速开始
使用指南
文档
API
高级功能
插件
全分片数据并行 (FullyShardedDataParallel)
class torch
.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)一个用于在数据并行工作节点间分片模块参数的包装器。
该设计灵感来源于Xu等人的论文以及DeepSpeed的ZeRO第三阶段技术。
FullyShardedDataParallel通常简称为FSDP。
要了解FSDP内部实现原理,请参阅FSDP技术说明。
示例:
>>
>
import torch
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
> torch.cuda.set_device(device_id)
>>
> sharded_module = FSDP(my_module)
>>
> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>
> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>
> loss = x.sum()
>>
> loss.backward()
>>
> optim.step()使用FSDP需要先包装你的模块,然后在之后初始化优化器。这是必要的,因为FSDP会改变参数变量。
在设置FSDP时,你需要考虑目标CUDA设备。如果设备有ID(dev_id),你有三个选项:
- 将模块放在该设备上
- 使用torch.cuda.set_device(dev_id)设置设备
- 将dev_id传入device_id构造函数参数
这确保了FSDP实例的计算设备是目标设备。对于选项1和3,FSDP初始化始终在GPU上进行。对于选项2,FSDP初始化发生在模块的当前设备上,可能是CPU。
如果你使用sync_module_states=True标志,需要确保模块在GPU上,或者使用device_id参数指定FSDP在构造函数中将模块移动到的CUDA设备。这是必要的,因为sync_module_states=True需要GPU通信。
FSDP还会负责将输入张量移动到前向方法的GPU计算设备上,因此你不需要手动将它们从CPU移动。
对于use_orig_params=True,ShardingStrategy.SHARD_GRAD_OP会暴露未分片的参数,而不是像ShardingStrategy.FULL_SHARD那样在前向之后的分片参数。如果你想检查梯度,可以使用summon_full_params方法并设置with_grads=True。
使用limit_all_gathers=True时,你可能会在FSDP前向之前看到一个CPU线程没有发出任何内核的间隙。这是有意为之,显示了速率限制器在起作用。以这种方式同步CPU线程可以防止为后续的all-gather操作过度分配内存,实际上不会延迟GPU内核的执行。
出于与自动梯度相关的原因,FSDP在前向和后向计算期间会用torch.Tensor视图替换托管模块的参数。如果你的模块的前向依赖于保存的参数引用而不是每次迭代重新获取引用,那么它将看不到FSDP新创建的视图,自动梯度将无法正常工作。
最后,当使用sharding_strategy=ShardingStrategy.HYBRID_SHARD且分片进程组为节点内、复制进程组为节点间时,设置NCCL_CROSS_NIC=1可以帮助在某些集群设置中提高复制进程组的all-reduce时间。
限制
使用FSDP时有几个需要注意的限制:
- 在使用CPU卸载时,FSDP目前不支持在no_sync()之外进行梯度累积。这是因为FSDP使用新减少的梯度而不是与任何现有梯度累积,这可能导致不正确的结果。
- FSDP不支持运行包含在FSDP实例中的子模块的前向传递。这是因为子模块的参数会被分片,但子模块本身不是FSDP实例,因此其前向传递不会适当地all-gather完整参数。
- 由于FSDP注册后向钩子的方式,它不支持双重后向。
- FSDP在冻结参数时有一些限制。对于use_orig_params=False,每个FSDP实例必须管理全部冻结或全部未冻结的参数。对于use_orig_params=True,FSDP支持混合冻结和未冻结参数,但建议避免这样做以防止高于预期的梯度内存使用。
- 截至PyTorch 1.12,FSDP对共享参数的支持有限。如果你的用例需要增强的共享参数支持,请在此问题中发帖。
- 你应该避免在不使用summon_full_params上下文的情况下在前向和后向之间修改参数,因为这些修改可能不会持久化。
参数
- module (nn.Module)– 这是要用FSDP包装的模块。
- process_group (Optional[Union[ProcessGroup*, Tuple[ProcessGroup*, ProcessGroup]]])– 这是模型分片的进程组,因此也是用于FSDP的all-gather和reduce-scatter集体通信的进程组。如果为- None,则FSDP使用默认进程组。对于混合分片策略如- ShardingStrategy.HYBRID_SHARD,用户可以传入一个进程组元组,分别表示分片和复制的组。如果为- None,则FSDP为用户构建进程组以在节点内分片和在节点间复制。(默认:- None)
- sharding_strategy (Optional[ShardingStrategy])– 这配置分片策略,可能会在内存节省和通信开销之间进行权衡。详情参见- ShardingStrategy。(默认:- FULL_SHARD)
- cpu_offload (Optional[CPUOffload])– 这配置CPU卸载。如果设置为- None,则不进行CPU卸载。详情参见- CPUOffload。(默认:- None)
- auto_wrap_policy (Optional[Union[Callable[[nn.Module,* [bool],* int ],* [bool]], ModuleWrapPolicy*, CustomPolicy]])– 这指定一个策略将FSDP应用于- module的子模块,这对于通信和计算重叠是必要的,从而影响性能。如果为- None,则FSDP仅应用于- module,用户应手动将FSDP应用于父模块(自底向上)。为方便起见,这直接接受- ModuleWrapPolicy,允许用户指定要包装的模块类(例如transformer块)。否则,这应该是一个可调用对象,接受三个参数- module: nn.Module、- recurse: bool和- nonwrapped_numel: int,并返回一个- bool,指定如果- recurse=False是否应对传入的- module应用FSDP,或者如果- recurse=True是否应继续遍历模块的子树。用户可以添加额外的参数到可调用对象。- torch.distributed.fsdp.wrap.py中的- size_based_auto_wrap_policy提供了一个示例可调用对象,如果模块子树中的参数超过100M numel,则应用FSDP。我们建议在应用FSDP后打印模型并根据需要进行调整。
示例:
>>
>
def custom_auto_wrap_policy(
>>
> module: nn.Module, >> recurse: bool, >> nonwrapped_numel: int, >>
# Additional custom arguments
>>
> min_num_params: int = int(1e8), >>
) -bool:
>>
>
return nonwrapped_numel >= min_num_params
>>
>
# Configure a custom `min_num_params`
>>
> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))- backward_prefetch (Optional[BackwardPrefetch])– 该参数用于配置所有-gather操作的显式反向预取。如果设为- None,FSDP将不执行反向预取,导致反向传播过程中没有通信与计算重叠。详情参见- BackwardPrefetch。(默认值:- BACKWARD_PRE)
- mixed_precision (Optional[MixedPrecision])– 该参数用于配置FSDP的原生混合精度。如果设为- None,则不使用混合精度。否则可以设置参数、缓冲区和梯度缩减的数据类型。详情参见- MixedPrecision。(默认值:- None)
- ignored_modules (Optional[Iterable[torch.nn.Module ]])– 该参数指定的模块及其子模块的参数和缓冲区将被当前FSDP实例忽略。直接列在- ignored_modules中的模块不应是- FullyShardedDataParallel实例,且任何已构建的- FullyShardedDataParallel子模块即使嵌套在当前实例下也不会被忽略。该参数可用于:1) 使用- auto_wrap_policy时避免在模块粒度上分片特定参数;2) 当参数分片不由FSDP管理时。(默认值:- None)
- param_init_fn (Optional[Callable[[nn.Module], None]])– 一个可调用对象- Callable[torch.nn.Module] -None,用于指定如何将当前位于meta设备上的模块初始化到实际设备。从v1.12开始,FSDP通过- is_meta检测带有参数或缓冲区的meta设备模块,并执行以下操作:如果指定了- param_init_fn则应用该函数,否则调用- nn.Module.reset_parameters()。两种情况下,实现应仅初始化该模块的参数/缓冲区,而非其子模块的,以避免重复初始化。此外,FSDP还支持通过torchdistX的(https://github.com/pytorch/torchdistX)- deferred_init()API进行延迟初始化——延迟模块会通过调用指定的- param_init_fn或torchdistX默认的- materialize_module()来初始化。如果指定了- param_init_fn,它将应用于所有meta设备模块,因此可能需要根据模块类型进行条件判断。FSDP在参数扁平化和分片之前调用初始化函数。
示例:
>>
> module = MyModule(device="meta")
>>
>
def my_init_fn(module: nn.Module):
>>
>
# E.g. initialize depending on the module type
>>
>
...
>>
> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>
>
print(next(fsdp_model.parameters()).device) # current CUDA device
>>
>
# With torchdistX
>>
> module = deferred_init.deferred_init(MyModule, device="cuda")
>>
>
# Will initialize via deferred_init.materialize_module().
>>
> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)- device_id (Optional[Union[int, torch.device]])– 指定FSDP初始化所在的CUDA设备,可以是- int或- torch.device类型,包括模块初始化(如需要)和参数分片过程。当- module位于CPU时指定该参数可提升初始化速度。若已设置默认CUDA设备(例如通过- torch.cuda.set_device),可传入- torch.cuda.current_device。(默认值:- None)
- sync_module_states ([bool])– 若为- True,每个FSDP模块会从rank 0广播模块参数和缓冲区,确保跨rank数据一致(会增加构造函数的通信开销)。这有助于通过- load_state_dict以内存高效方式加载- state_dict检查点。示例用法参见- FullStateDictConfig。(默认值:- False)
- forward_prefetch ([bool])– 若为- True,FSDP会显式在当前前向计算完成前预取下一轮前向传播的all-gather操作。仅适用于CPU密集型工作负载,提前发起all-gather可能提升计算重叠度。由于预取遵循首轮迭代执行顺序,该参数仅适用于静态图模型。(默认值:- False)
- limit_all_gathers ([bool])– 若为- True,FSDP会显式同步CPU线程,确保GPU内存仅被两个连续FSDP实例占用(当前执行计算的实例和预取了all-gather的下一个实例)。若为- False,则允许CPU线程无额外同步地发起all-gather。(默认值:- True)该特性常被称为"速率限制器",仅在内存压力低的CPU密集型场景下可设为- False,此时CPU线程可激进提交所有内核而无需考虑GPU内存占用。
- use_orig_params ([bool])– 设为- True时,FSDP将使用模块的原始参数。通过- nn.Module.named_parameters()暴露原始参数而非内部- FlatParameter,使得优化器基于原始参数运行(支持每个原始参数的独立超参)。FSDP会保留原始参数变量,并在未分片/分片状态间转换其数据(始终分别作为底层未分片/分片- FlatParameter的视图)。当前算法中分片形式始终为1D,会丢失原始张量结构。原始参数的数据可能全部/部分/不存在于当前rank,不存在时其数据表现为空张量。用户不应编写依赖分片形式数据的程序。使用- torch.compile()必须设为- True。设为- False会通过- nn.Module.named_parameters()暴露内部- FlatParameter。(默认值:- False)
- ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]])– 指定不由该FSDP实例管理的参数或模块,意味着这些参数不会被分片且梯度不会跨rank规约。该参数与现有- ignored_modules参数功能统一,未来可能弃用- ignored_modules。为保持向后兼容,同时保留两个参数,但FSDP要求二者只能有一个为非- None。
- device_mesh (Optional[DeviceMesh])– 可作为process_group的替代方案。传入时FSDP会使用底层process_group执行all-gather和reduce-scatter集合通信,因此这两个参数需互斥。对于- ShardingStrategy.HYBRID_SHARD等混合分片策略,可传入2D DeviceMesh替代process_group元组。2D FSDP+TP场景下必须使用device_mesh而非process_group。更多DeviceMesh信息请参阅:
apply(fn)对自身以及每个子模块(通过.children()返回)递归应用fn函数。
典型用途包括初始化模型参数(参见torch.nn.init文档)。
与torch.nn.Module.apply相比,此版本在应用fn前会先收集完整参数。注意不应在另一个summon_full_params上下文中调用该方法。
参数
- fn (Module- -None)– 要应用于每个子模块的函数
返回
 自身
返回类型
 Module
check_is_root()检查此实例是否为根 FSDP 模块。
返回类型:bool
clip_grad_norm_(max_norm, norm_type=2.0)对所有参数的梯度范数进行裁剪。
该范数计算时将所有权重参数的梯度视为单个向量,并原地修改这些梯度值。
参数
- max_norm (float 或 int)– 梯度的最大范数值
- norm_type (float 或 int)– 所用p-范数类型。可设为- 'inf'表示无穷范数
返回
参数的总范数值(视为单个向量)。
返回类型:Tensor
若所有FSDP实例都使用NO_SHARD策略(即梯度未跨rank分片),可直接使用torch.nn.utils.clip_grad_norm_()。
若存在FSDP实例使用分片策略(即非NO_SHARD策略),则应改用本方法而非torch.nn.utils.clip_grad_norm_(),因为本方法能正确处理跨rank分片的梯度。
返回的总范数值将根据PyTorch的类型提升规则,采用所有参数/梯度中"最大"的数据类型。例如:若所有参数/梯度使用低精度类型,则返回范数保持该低精度类型;但只要存在至少一个FP32精度的参数/梯度,返回范数将采用FP32类型。
警告:由于涉及集合通信操作,必须在所有rank上调用本方法。
static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)展平分片的优化器状态字典。
该API与shard_full_optim_state_dict()类似,唯一区别在于输入的sharded_optim_state_dict应来自sharded_optim_state_dict()的返回结果。因此,每个rank上都会执行all-gather调用来收集ShardedTensor。
参数
- sharded_optim_state_dict (Dict[str, Any])- 与未展平参数对应的优化器状态字典,包含分片的优化器状态。
- model ( torch.nn.Module )- 参考- shard_full_optim_state_dict()。
- optim ( torch.optim.Optimizer )- 用于- model参数的优化器。
返回值:参考shard_full_optim_state_dict()。
返回类型:dict[str, Any]
forward(*args, **kwargs)对封装模块执行前向传播,同时插入 FSDP 特有的前向分片与后向分片逻辑。
返回类型:Any
*static* fsdp_modules(module, root_only=False)返回所有嵌套的FSDP实例。
这可能包含module本身,且当root_only=True时仅包含FSDP根模块。
参数
- module ( torch.nn.Module )– 根模块,可能是也可能不是一个- FSDP模块。
- root_only ([bool])– 是否仅返回FSDP根模块。(默认值:- False)
返回
嵌套在输入module中的FSDP模块。
返回类型:List [FullyShardedDataParallel]
static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)返回完整的优化器状态字典。
该方法会在 rank 0 上整合完整的优化器状态,并以 dict 形式返回,遵循 torch.optim.Optimizer.state_dict() 的规范,即包含 "state" 和 "param_groups" 键。model 中包含的 FSDP 模块中的扁平化参数会被映射回其原始的非扁平化参数。
由于使用了集体通信操作,此方法需要在所有 rank 上调用。但如果 rank0_only=True,则仅在 rank 0 上填充状态字典,其他所有 rank 返回空字典。
与 torch.optim.Optimizer.state_dict() 不同,本方法使用完整参数名作为键(而非参数 ID)。
与 torch.optim.Optimizer.state_dict() 类似,优化器状态字典中包含的张量不会被克隆,因此可能存在别名意外。建议立即保存返回的优化器状态字典(例如使用 torch.save())以获得最佳实践。
参数
- model (torch.nn.Module)– 根模块(可能是也可能不是- FullyShardedDataParallel实例),其参数已传入优化器- optim。
- optim (torch.optim.Optimizer)– 用于- model参数的优化器。
- optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]])– 传入优化器- optim的输入,表示参数组的- list或可迭代参数;如果为- None,则该方法假定输入为- model.parameters()。此参数已弃用,无需再传递。(默认值:- None)
- rank0_only ([bool])– 如果为- True,仅在 rank 0 上保存填充的字典;如果为- False,在所有 rank 上保存。(默认值:- True)
- group (dist.ProcessGroup)– 模型的进程组,如果使用默认进程组则为- None。(默认值:- None)
返回值:一个 dict,包含 model 原始非扁平化参数的优化器状态,并遵循 torch.optim.Optimizer.state_dict() 规范包含 “state” 和 “param_groups” 键。如果 rank0_only=True,则非零 rank 返回空字典。
返回类型:Dict[str, Any]
static get_state_dict_type(module)获取以 module 为根的 FSDP 模块的 state_dict_type 及其对应配置。
目标模块不必是 FSDP 模块。
返回值:返回一个 StateDictSettings 对象,包含当前设置的 state_dict_type 以及 state_dict / optim_state_dict 配置。
异常
- 如果不同 FSDP 子模块的 StateDictSettings 不一致,抛出 AssertionError
返回类型:StateDictSettings
property module: Module返回被包装的模块。
named_buffers(*args, **kwargs)返回一个遍历模块缓冲区的迭代器,同时生成缓冲区的名称和缓冲区本身。
在 summon_full_params() 上下文管理器内部时,会拦截缓冲区名称并移除所有特定于FSDP的扁平化缓冲区前缀。
返回类型为 Iterator [tuple [str, torch.Tensor]]
named_parameters(*args, **kwargs)返回一个遍历模块参数的迭代器,同时生成参数名称和参数本身。
在 summon_full_params() 上下文管理器内部时,会拦截参数名称并移除所有特定于FSDP的扁平化参数前缀。
返回类型为 Iterator [tuple [str, [torch.nn.parameter.Parameter]]
no_sync()禁用跨FSDP实例的梯度同步。
在此上下文中,梯度将累积在模块变量中,这些梯度会在退出上下文后的首次前向-反向传播过程中同步。此功能应仅用于根FSDP实例,并将递归应用于所有子FSDP实例。
注意:这可能导致更高的内存使用量,因为FSDP会累积完整的模型梯度(而非梯度分片),直到最终同步完成。
注意:与CPU卸载功能同时使用时,在上下文管理器内部梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。
返回类型:生成器
static optim_state_dict(model, optim, optim_state_dict=None, group=None)转换分片模型对应的优化器状态字典。
给定的状态字典可转换为以下三种类型之一:
- 完整优化器状态字典 2) 分片优化器状态字典 3) 本地优化器状态字典
对于完整优化器状态字典,所有状态均未展平且未分片。可通过 state_dict_type() 指定仅限 Rank0 和仅限 CPU 以避免内存溢出。
对于分片优化器状态字典,所有状态均未展平但已分片。可通过 state_dict_type() 指定仅限 CPU 以进一步节省内存。
对于本地状态字典,不会执行任何转换。但状态会从 nn.Tensor 转换为 ShardedTensor 以表示其分片特性(当前尚未支持此功能)。
示例:
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
>
from torch.distributed.fsdp import StateDictType
>>
>
from torch.distributed.fsdp import FullStateDictConfig
>>
>
from torch.distributed.fsdp import FullOptimStateDictConfig
>>
>
# Save a checkpoint
>>
> model, optim = ...
>>
> FSDP.set_state_dict_type(
>>
> model, >> StateDictType.FULL_STATE_DICT, >> FullStateDictConfig(rank0_only=False), >> FullOptimStateDictConfig(rank0_only=False), >>
)
>>
> state_dict = model.state_dict()
>>
> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>
> save_a_checkpoint(state_dict, optim_state_dict)
>>
>
# Load a checkpoint
>>
> model, optim = ...
>>
> state_dict, optim_state_dict = load_a_checkpoint()
>>
> FSDP.set_state_dict_type(
>>
> model, >> StateDictType.FULL_STATE_DICT, >> FullStateDictConfig(rank0_only=False), >> FullOptimStateDictConfig(rank0_only=False), >>
)
>>
> model.load_state_dict(state_dict)
>>
> optim_state_dict = FSDP.optim_state_dict_to_load(
>>
> model, optim, optim_state_dict
>>
>
)
>>
> optim.load_state_dict(optim_state_dict)参数
- model ( torch.nn.Module )– 根模块(可能是也可能不是- FullyShardedDataParallel实例),其参数已传入优化器- optim。
- optim ( torch.optim.Optimizer )– 用于- model参数的优化器。
- optim_state_dict (Dict[str, Any])– 需要转换的目标优化器状态字典。若值为None,将使用optim.state_dict()。(默认值:- None)
- group (dist.ProcessGroup)– 模型参数分片所在的进程组,若使用默认进程组则为- None。(默认值:- None)
返回值:一个包含model优化器状态的dict。优化器状态的分片基于state_dict_type。
返回类型:Dict[str, Any]
static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)将优化器状态字典转换为可加载到与FSDP模型关联的优化器中的格式。
给定一个通过 optim_state_dict() 转换得到的 optim_state_dict,该方法会将其转换为扁平化的优化器状态字典,该字典可加载到 model 的优化器 optim 中。注意:model 必须是通过 FullyShardedDataParallel 进行分片的模型。
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
>
from torch.distributed.fsdp import StateDictType
>>
>
from torch.distributed.fsdp import FullStateDictConfig
>>
>
from torch.distributed.fsdp import FullOptimStateDictConfig
>>
>
# Save a checkpoint
>>
> model, optim = ...
>>
> FSDP.set_state_dict_type(
>>
> model, >> StateDictType.FULL_STATE_DICT, >> FullStateDictConfig(rank0_only=False), >> FullOptimStateDictConfig(rank0_only=False), >>
)
>>
> state_dict = model.state_dict()
>>
> original_osd = optim.state_dict()
>>
> optim_state_dict = FSDP.optim_state_dict(
>>
> model, >> optim, >> optim_state_dict=original_osd
>>
>
)
>>
> save_a_checkpoint(state_dict, optim_state_dict)
>>
>
# Load a checkpoint
>>
> model, optim = ...
>>
> state_dict, optim_state_dict = load_a_checkpoint()
>>
> FSDP.set_state_dict_type(
>>
> model, >> StateDictType.FULL_STATE_DICT, >> FullStateDictConfig(rank0_only=False), >> FullOptimStateDictConfig(rank0_only=False), >>
)
>>
> model.load_state_dict(state_dict)
>>
> optim_state_dict = FSDP.optim_state_dict_to_load(
>>
> model, optim, optim_state_dict
>>
>
)
>>
> optim.load_state_dict(optim_state_dict)参数
- model (torch.nn.Module)– 根模块(可能是也可能不是- FullyShardedDataParallel实例),其参数已传入优化器- optim。
- optim (torch.optim.Optimizer)– 用于- model参数的优化器。
- optim_state_dict (Dict[str, Any])– 待加载的优化器状态字典。
- is_named_optimizer ([bool])– 该优化器是否为NamedOptimizer或KeyedOptimizer。仅当- optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时设为True。
- load_directly ([bool])– 若设为True,本API将在返回结果前自动调用optim.load_state_dict(result);否则用户需自行调用- optim.load_state_dict()(默认值:- False)。
- group (dist.ProcessGroup)– 模型参数分片所在的进程组,若使用默认进程组则为- None(默认值:- None)。
返回类型:dict[str, Any]
register_comm_hook(state, hook)注册一个通信钩子。
该功能是一项增强,为用户提供了一个灵活的钩子,可以指定FSDP如何在多个工作节点间聚合梯度。
这个钩子可用于实现多种算法,例如GossipGrad和梯度压缩,这些算法在使用FullyShardedDataParallel训练时涉及不同的参数同步通信策略。
警告:FSDP通信钩子必须在初始前向传播运行前注册,且只能注册一次。
参数
- state ( object )- 传递给钩子以在训练过程中维护任何状态信息。
示例包括梯度压缩中的误差反馈、GossipGrad中下一次通信的对等节点等。
该状态由每个工作节点本地存储,并由该工作节点上的所有梯度张量共享。
- hook (Callable)- 可调用对象,具有以下签名之一:
- hook: Callable[torch.Tensor] -None:
该函数接收一个Python张量,表示与该FSDP单元包装的模型(未被其他FSDP子单元包装的部分)对应的所有变量的完整、展平、未分片的梯度。
然后执行所有必要的处理并返回None;
- hook: Callable[torch.Tensor, torch.Tensor] -None:
该函数接收两个Python张量,第一个表示与该FSDP单元包装的模型(未被其他FSDP子单元包装的部分)对应的所有变量的完整、展平、未分片的梯度。第二个表示预分配大小的张量,用于存储归约后的分片梯度块。
在这两种情况下,可调用对象都会执行所有必要的处理并返回None。
签名1的可调用对象预期处理NO_SHARD情况下的梯度通信。
签名2的可调用对象预期处理分片情况下的梯度通信。
static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)将优化器状态字典 optim_state_dict 的键类型重新映射为 optim_state_key_type。
该功能可用于实现带有 FSDP 实例的模型与普通模型之间优化器状态字典的兼容性。
若要将 FSDP 完整优化器状态字典(即来自 full_optim_state_dict() 的字典)重新映射为使用参数 ID 键,并使其可加载至未封装模型:
>>
> wrapped_model, wrapped_optim = ...
>>
> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>
> nonwrapped_model, nonwrapped_optim = ...
>>
> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>
> nonwrapped_optim.load_state_dict(rekeyed_osd)要将普通优化器状态字典(来自未封装模型)重新映射为可加载到封装模型中的格式:
>>
> nonwrapped_model, nonwrapped_optim = ...
>>
> osd = nonwrapped_optim.state_dict()
>>
> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>
> wrapped_model, wrapped_optim = ...
>>
> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>
> wrapped_optim.load_state_dict(sharded_osd)返回
优化器状态字典,使用optim_state_key_type指定的参数键重新映射键名。
返回类型:Dict[str, Any]
static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)将完整的优化器状态字典从 rank 0 分发到所有其他 ranks。
返回每个 rank 上的分片优化器状态字典。
返回值与 shard_full_optim_state_dict() 相同,且在 rank 0 上,
 第一个参数应为 full_optim_state_dict() 的返回值。
示例:
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
> model, optim = ...
>>
> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0
>>
>
# Define new model with possibly different world size
>>
> new_model, new_optim, new_group = ...
>>
> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>
> new_optim.load_state_dict(sharded_osd)注意:shard_full_optim_state_dict() 和 scatter_full_optim_state_dict() 均可用于获取待加载的分片优化器状态字典。假设完整优化器状态字典存储在CPU内存中:
- 前者要求每个进程在CPU内存中保存完整字典,各进程独立进行分片且无需通信
- 后者仅要求rank 0在CPU内存中保存完整字典,由rank 0将每个分片移至GPU内存(用于NCCL)并通过通信分发到对应进程
因此,前者总CPU内存开销更高,而后者通信开销更大。
参数说明
- full_optim_state_dict (Optional[Dict[str, Any]])- 对应未展平参数的优化器状态字典,在rank 0上保存完整非分片状态;非0 rank会忽略该参数
- model (torch.nn.Module)- 根模块(可能是也可能不是- FullyShardedDataParallel实例),其参数与- full_optim_state_dict中的优化器状态对应
- optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]])- 传入优化器的输入,可以是参数组的- list或参数的可迭代对象;若为- None则默认使用- model.parameters()。此参数已弃用,无需再传递(默认值:- None)
- optim (Optional[torch.optim.Optimizer])- 将加载本方法返回状态字典的优化器。推荐使用此参数替代- optim_input(默认值:- None)
- group (dist.ProcessGroup)- 模型使用的进程组,若为- None则使用默认进程组(默认值:- None)
返回值:返回重构后的完整优化器状态字典,其中:
- 参数映射为展平后的形式(而非原始未展平参数)
- 仅包含当前rank对应的优化器状态部分
返回类型:Dict[str, Any]
static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)设置目标模块所有子级FSDP模块的state_dict_type。
同时支持(可选)配置模型和优化器的状态字典。
目标模块不必是FSDP模块。如果目标模块本身是FSDP模块,其state_dict_type也会被修改。
注意:此API应仅对顶层(根)模块调用。
注意:当根FSDP模块被其他nn.Module包裹时,此API允许用户透明地使用常规state_dict接口来保存模型检查点。例如以下场景:确保对非FSDP实例调用state_dict方法,同时对FSDP实例则转为分片状态字典实现:
示例:
>>
> model = DDP(FSDP(...))
>>
> FSDP.set_state_dict_type(
>>
> model, >> StateDictType.SHARDED_STATE_DICT, >> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>
)
>>
> param_state_dict = model.state_dict()
>>
> optim_state_dict = FSDP.optim_state_dict(model, optim)参数
- module ( torch.nn.Module )– 根模块。
- state_dict_type (StateDictType)– 要设置的期望- state_dict_type。
- state_dict_config (Optional[StateDictConfig])– 目标- state_dict_type的配置。
- optim_state_dict_config (Optional[OptimStateDictConfig])– 优化器状态字典的配置。
返回值:返回一个包含模块先前状态字典类型和配置的 StateDictSettings 对象。
返回类型:StateDictSettings
static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)切分完整优化器状态字典。
将 full_optim_state_dict 中的状态从非展平参数重新映射为展平参数,并限制为仅包含当前秩(rank)对应的优化器状态部分。
第一个参数应为 full_optim_state_dict() 的返回值。
示例:
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
> model, optim = ...
>>
> full_osd = FSDP.full_optim_state_dict(model, optim)
>>
> torch.save(full_osd, PATH)
>>
>
# Define new model with possibly different world size
>>
> new_model, new_optim = ...
>>
> full_osd = torch.load(PATH)
>>
> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>
> new_optim.load_state_dict(sharded_osd)注意:shard_full_optim_state_dict() 和 scatter_full_optim_state_dict() 均可用于获取待加载的分片优化器状态字典。假设完整优化器状态字典驻留在CPU内存中:
- 前者要求每个rank在CPU内存中保存完整字典,各rank独立进行分片且无需通信
- 后者仅要求rank 0在CPU内存中保存完整字典,由rank 0将各分片移至GPU内存(用于NCCL)并通过通信分发到对应rank
因此,前者总CPU内存开销更高,而后者的通信开销更大。
参数说明
- full_optim_state_dict (Dict[str, Any])- 对应未展平参数的优化器状态字典,包含完整未分片的优化器状态
- model ( torch.nn.Module )- 根模块(可能是也可能不是- FullyShardedDataParallel实例),其参数与- full_optim_state_dict中的优化器状态相对应
- optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]])- 传递给优化器的输入,可以是参数组的- list或可迭代参数集合;若为- None则默认使用- model.parameters()。此参数已弃用,无需再传递(默认值:- None)
- optim (Optional[torch.optim.Optimizer ])- 将加载本方法返回状态字典的优化器。推荐优先使用此参数而非- optim_input(默认值:- None)
返回值:返回重构后的完整优化器状态字典,该字典:
 1、已从未展平参数映射为展平参数
 2、仅包含当前rank对应的优化器状态部分
返回类型:Dict[str, Any]
static sharded_optim_state_dict(model, optim, group=None)返回优化器状态字典的分片形式。
该API与full_optim_state_dict()类似,但会将所有非零维状态分块为ShardedTensor以节省内存。
注意:只有当模型state_dict是通过上下文管理器with state_dict_type(SHARDED_STATE_DICT):导出时,才应使用此API。
具体用法请参考full_optim_state_dict()。
警告:返回的状态字典包含ShardedTensor,不能直接被常规的optim.load_state_dict使用。
返回类型:dict[str, Any]
static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)设置目标模块下所有FSDP子模块的state_dict_type。
此上下文管理器功能与set_state_dict_type()相同。详情请参阅set_state_dict_type()的文档说明。
示例:
>>
> model = DDP(FSDP(...))
>>
>
with FSDP.state_dict_type(
>>
> model, >> StateDictType.SHARDED_STATE_DICT, >>
):
>>
> checkpoint = model.state_dict()参数
- module ( torch.nn.Module )– 根模块。
- state_dict_type (StateDictType)– 要设置的期望- state_dict_type。
- state_dict_config (Optional[StateDictConfig ])– 目标- state_dict_type对应的模型- state_dict配置。
- optim_state_dict_config (Optional[OptimStateDictConfig ])– 目标- state_dict_type对应的优化器- state_dict配置。
返回类型:生成器
static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)通过此上下文管理器暴露FSDP实例的完整参数。
在模型完成前向/反向传播后,可用于获取参数进行额外处理或检查。它可以接受非FSDP模块,并根据recurse参数递归地为所有包含的FSDP模块及其子模块召唤完整参数。
注意:可在内部FSDP上使用。
注意:不可在前向或反向传播过程中使用,也不能在该上下文中启动前向/反向传播。
注意:上下文管理器退出后,参数将恢复为本地分片状态,存储行为与前向传播相同。
注意:完整参数可被修改,但只有对应本地参数分片的部分会在上下文退出后保留(除非设置writeback=False,此时修改会被丢弃)。当FSDP不对参数分片时(当前仅world_size == 1或NO_SHARD配置),无论writeback如何设置,修改都会被保留。
注意:此方法适用于非FSDP模块(可能包含多个独立FSDP单元)。此时给定参数将应用于所有包含的FSDP单元。
警告:当前不支持rank0_only=True与writeback=True同时使用,会触发错误。因为上下文中各rank的模型参数形状不同,退出时写入会导致跨rank不一致。
警告:offload_to_cpu与rank0_only=False组合会导致完整参数被冗余复制到同一机器的CPU内存中,可能引发CPU OOM风险。建议配合rank0_only=True使用offload_to_cpu。
参数说明
- recurse([bool], 可选) – 是否递归召唤嵌套FSDP实例的所有参数(默认:True)
- writeback([bool], 可选) – 若为- False,上下文退出时丢弃参数修改;禁用此选项可略微提升效率(默认:True)
- rank0_only([bool], 可选) – 若为- True,仅全局rank 0会物化完整参数,其他rank保持分片参数。注意- rank0_only=True与- writeback=True的组合不被支持,因上下文中各rank参数形状不同,退出时写入会导致不一致
- offload_to_cpu([bool], 可选) – 若为- True,完整参数将卸载到CPU。当前仅分片参数会触发卸载(- world_size=1或- NO_SHARD配置除外)。建议配合- rank0_only=True使用以避免重复卸载到同一CPU内存
- with_grads([bool], 可选) – 若为- True,梯度会随参数一起解除分片。当前仅当FSDP构造器传入- use_orig_params=True且本方法设置- offload_to_cpu=False时支持(默认:- False)
返回类型
 生成器
class torch
.distributed.fsdp.BackwardPrefetch(value)该配置启用了显式的反向预取功能,通过在后向传递中实现通信与计算的重叠来提升吞吐量,但会略微增加内存使用量。
- BACKWARD_PRE:实现最大程度的重叠,但内存使用量也最高。该模式会在当前参数组的梯度计算之前预取下一组参数。这使得下一次全收集操作与当前梯度计算能够重叠执行,内存峰值时会同时保留当前参数组、下一组参数以及当前梯度数据。
- BACKWARD_POST:实现较少重叠,但内存需求更低。该模式在当前参数组的梯度计算完成后才预取下一组参数。这使得当前规约散射操作与下一组梯度计算能够重叠执行,并在为下一组参数分配内存前释放当前参数组,内存峰值时仅保留下一组参数和当前梯度数据。
- FSDP的backward_prefetch参数支持设为None以完全禁用反向预取。该模式无任何重叠效果且不会增加内存开销。通常不建议采用此设置,因为它可能显著降低吞吐性能。
技术背景说明:对于使用NCCL后端的单个进程组,所有集合操作(即使来自不同流)都会争用同一设备上的NCCL流,这意味着集合操作的发起顺序将直接影响重叠效果。上述两种反向预取值对应不同的操作发起顺序。
class torch
.distributed.fsdp.ShardingStrategy(value)这里指定了FullyShardedDataParallel用于分布式训练的分片策略。
- FULL_SHARD:参数、梯度和优化器状态均进行分片。
对于参数,该策略在前向计算前执行解分片(通过all-gather操作),前向计算后重新分片,反向计算前再次解分片,反向计算后重新分片。对于梯度,在反向计算后通过reduce-scatter操作进行同步和分片。分片的优化器状态由每个rank本地更新。
- SHARD_GRAD_OP:计算过程中梯度和优化器状态保持分片,此外参数在计算外也保持分片。
对于参数,该策略在前向计算前执行解分片,前向计算后不重新分片,仅在反向计算后重新分片。分片的优化器状态由每个rank本地更新。在no_sync()上下文中,反向计算后参数不会重新分片。
- NO_SHARD:参数、梯度和优化器状态不进行分片,而是像PyTorch的- DistributedDataParallelAPI那样跨rank复制。对于梯度,该策略在反向计算后通过all-reduce操作进行同步。未分片的优化器状态由每个rank本地更新。
- HYBRID_SHARD:在节点内应用- FULL_SHARD策略,同时在节点间复制参数。由于昂贵的all-gather和reduce-scatter操作仅在节点内执行,这可以减少通信量,对中等规模模型可能更具性能优势。
- _HYBRID_SHARD_ZERO2:在节点内应用- SHARD_GRAD_OP策略,同时在节点间复制参数。与- HYBRID_SHARD类似,但由于前向计算后未释放解分片的参数,节省了反向计算前的all-gather操作,可能提供更高的吞吐量。
class torch
.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<
class 'torch.nn.modules.batchnorm._BatchNorm'
>
, ))此配置用于FSDP原生的混合精度训练。
变量说明
- param_dtype (Optional[torch.dtype])- 指定前向传播和反向传播期间模型参数的数据类型,从而决定前向和反向计算的数据类型。在前向和反向之外,分片参数保持全精度(例如用于优化器步骤),而模型检查点时参数始终以全精度保存。(默认:- None)
- reduce_dtype (Optional[torch.dtype])- 指定梯度归约(即reduce-scatter或all-reduce)的数据类型。若为- None但- param_dtype不为- None,则采用- param_dtype值,仍以低精度运行梯度归约。允许与- param_dtype不同,例如强制梯度归约以全精度运行。(默认:- None)
- buffer_dtype (Optional[torch.dtype])- 指定缓冲区的数据类型。FSDP不对缓冲区进行分片,而是在首次前向传播时将其转换为- buffer_dtype并保持该类型。模型检查点时,除- LOCAL_STATE_DICT外缓冲区均以全精度保存。(默认:- None)
- keep_low_precision_grads ([bool])- 若为- False,FSDP在反向传播后将梯度提升至全精度以备优化器步骤使用;若为- True,则保持梯度为归约时的数据类型,可节省支持低精度运行的自定义优化器的内存。(默认:- False)
- cast_forward_inputs ([bool])- 若为- True,FSDP模块将其前向传播的args和kwargs转换为- param_dtype,确保参数与输入数据类型匹配以满足多数运算要求。当仅对部分FSDP模块应用混合精度时可能需要设为- True,此时混合精度子模块需重新转换输入。(默认:- False)
- cast_root_forward_inputs ([bool])- 若为- True,根FSDP模块会覆盖- cast_forward_inputs值,将其前向传播的args和kwargs转换为- param_dtype。非根FSDP模块不受影响。(默认:- True)
- _module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]])- 指定使用- auto_wrap_policy时忽略混合精度的模块类:这些类的模块将单独应用FSDP且禁用混合精度(导致最终FSDP构造偏离指定策略)。未指定- auto_wrap_policy时此参数无效。该API为实验性质可能变更。(默认:- (_BatchNorm,))
注意事项
此API为实验性质,可能发生变化。
仅浮点张量会被转换为指定数据类型。
在summon_full_params中,参数强制转为全精度,但缓冲区不受影响。
即使输入为float16或bfloat16等低精度,层归一化和批归一化仍以float32累加。仅为这些归一化模块禁用混合精度意味着仿射参数保持float32,但会导致额外的all-gather和reduce-scatter操作,可能降低效率。若任务允许,建议仍对这些模块应用混合精度。
默认情况下,若用户传入包含_BatchNorm模块的模型并指定auto_wrap_policy,批归一化模块将单独应用FSDP且禁用混合精度。详见_module_classes_to_ignore参数。
MixedPrecision默认设置cast_root_forward_inputs=True且cast_forward_inputs=False。根FSDP实例的cast_root_forward_inputs优先于cast_forward_inputs,非根实例的cast_root_forward_inputs值被忽略。典型场景下(所有FSDP实例具有相同MixedPrecision配置且仅需在模型前向开始时转换输入至param_dtype),默认设置已足够。
对于具有不同MixedPrecision配置的嵌套FSDP实例,建议通过单独设置cast_forward_inputs来配置各实例前向传播前的输入转换。此时由于转换发生在各FSDP实例前向之前,父FSDP实例应使其非FSDP子模块在FSDP子模块之前运行,避免因不同MixedPrecision配置导致激活数据类型变化。
示例:
>>
> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>
> model[1] = FSDP(
>>
> model[1], >> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>
)
>>
> model = FSDP(
>>
> model, >> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>
)上面的示例展示了正常运行的情况。反之,如果将 model[1] 替换为 model[0],即让使用不同 MixedPrecision 的子模块先执行前向计算,那么 model[1] 就会错误地接收到 float16 类型的激活值,而非预期的 bfloat16 类型。
class torch
.distributed.fsdp.CPUOffload(offload_params=False)此配置用于启用 CPU 卸载功能。
变量说明
- offload_params ([bool])– 指定是否在参数不参与计算时将其卸载到 CPU。若设为- True,则梯度也会被卸载至 CPU,这意味着优化器步骤将在 CPU 上执行。
class torch
.distributed.fsdp.StateDictConfig(offload_to_cpu=False)StateDictConfig 是所有 state_dict 配置类的基类。用户需要实例化其子类(例如 FullStateDictConfig)来配置 FSDP 所支持的对应 state_dict 类型的相关设置。
变量说明
- offload_to_cpu ([bool])– 若设为- True,FSDP 会将状态字典的值卸载到 CPU;若设为- False,则保留在 GPU 上。(默认值:- False)
class torch
.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)FullStateDictConfig 是一个配置类,专为配合 StateDictType.FULL_STATE_DICT 使用而设计。我们建议在保存完整状态字典时,同时启用 offload_to_cpu=True 和 rank0_only=True 参数,以分别节省 GPU 内存和 CPU 内存。该配置类需通过 state_dict_type() 上下文管理器使用,示例如下:
>>
>
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>
> fsdp = FSDP(model, auto_wrap_policy=...)
>>
> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>
>
with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>
> state = fsdp.state_dict()
>>
>
# `state` will be empty on non rank 0 and contain CPU tensors on rank 0、>># To reload checkpoint for inference, finetuning, transfer learning, etc:
>>
> model = model_fn() # Initialize model in preparation for wrapping with FSDP
>>
>
if dist.get_rank() == 0:
>>
>
# Load checkpoint only on rank 0 to avoid memory redundancy
>>
> state_dict = torch.load("my_checkpoint.pt")
>>
> model.load_state_dict(state_dict)
>>
>
# All ranks initialize FSDP module as usual. `sync_module_states` argument
>>
>
# communicates loaded checkpoint states from rank 0 to rest of the world.
>>
> fsdp = FSDP(
... model,
... device_id=torch.cuda.current_device(),
... auto_wrap_policy=...,
... sync_module_states=True,
... )
>>
>
# After this point, all ranks have FSDP model with loaded checkpoint.变量
- rank0_only ([bool])– 如果设为- True,则仅 rank 0 进程保存完整的状态字典,非零 rank 进程保存空字典。如果设为- False,则所有 rank 进程都会保存完整的状态字典。(默认值:- False)
class torch
.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)ShardedStateDictConfig 是一个配置类,专为与 StateDictType.SHARDED_STATE_DICT 配合使用而设计。
变量说明
- _use_dtensor ([bool])– 若设为- True,FSDP 会将状态字典值保存为- DTensor;若设为- False,则保存为- ShardedTensor。(默认值:- False)
警告:_use_dtensor 是 ShardedStateDictConfig 的私有字段,FSDP 通过该字段决定状态字典值的类型。用户不应手动修改 _use_dtensor。
class torch
.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)class torch
.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)OptimStateDictConfig 是所有 optim_state_dict 配置类的基类。用户应实例化子类(例如 FullOptimStateDictConfig)来配置 FSDP 支持的对应 optim_state_dict 类型的设置。
变量说明
- offload_to_cpu ([bool])– 若设为- True,FSDP 会将状态字典的张量值卸载到 CPU;若设为- False,则保留在原始设备上(除非启用了参数 CPU 卸载功能,否则原始设备为 GPU)。(默认值:- True)
class torch
.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)变量
- rank0_only ([bool])– 如果设为- True,则仅rank 0会保存完整的状态字典,非零rank保存空字典。如果设为- False,则所有rank都会保存完整的状态字典。(默认值:- False)
class torch
.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)ShardedOptimStateDictConfig 是一个配置类,专为与 StateDictType.SHARDED_STATE_DICT 配合使用而设计。
变量说明
- _use_dtensor ([bool])– 若设为- True,FSDP 会将状态字典的值保存为- DTensor;若设为- False,则保存为- ShardedTensor。(默认值:- False)
警告_use_dtensor 是 ShardedOptimStateDictConfig 的私有字段,FSDP 通过它来决定状态字典值的类型。用户不应手动修改此字段。
class torch
.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)class torch
.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config:torch.distributed.fsdp.api.StateDictConfig )2025-08-20(三)
 
                    
                     
                    
                 
                    
                 
                
            
         
         浙公网安备 33010602011771号
浙公网安备 33010602011771号