PyTorch DTensor Explained

PyTorch DTensor

PyTorch introduces distributed tensor primitives that make it easier to write distributed computation in the SPMD (single program, multiple devices) style. These primitives can express the concepts of sharding and replication. An example:

# run command: uv run torchrun --standalone --nnodes=1 --nproc-per-node=4 example.py

import os
import torch
import torch.distributed as dist

from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

torch.random.manual_seed(41)  # fix the seed so every process creates the same tensor on CPU

world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

big_tensor = torch.randn(100000, 88)

my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])

idx = mesh.get_rank()

print(f"cur rank: {idx}\n {my_dtensor}")

process_group = mesh.get_group()

# gather every rank's local shard back into the full tensor
output = [torch.zeros(100000 // world_size, 88, device="cuda") for _ in range(world_size)]
dist.all_gather(output, my_dtensor.to_local(), group=process_group)

output = [o.cpu() for o in output]

res = torch.cat(output)

if torch.allclose(res, big_tensor):
    print("Success!!!")
else:
    print(res, big_tensor)
    print("Exist bug!!!")

In this example every process builds the same large tensor on the CPU; the seed is fixed so that the tensor is identical across processes, which lets us verify the correctness of the all_gather. From the mesh we obtain the rank of the current process, and after distribute_tensor each rank holds only its local shard. Note that distribute_tensor uses rank 0's data as the source when scattering the shards, which preserves single-device semantics. We then use a collective communication primitive (all_gather) to reassemble the full tensor, and finally we check that the sharding result is correct.
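
Since Shard(0) splits the tensor along dim 0 following torch.chunk semantics, each rank can also check its local shard directly against the corresponding chunk of the original tensor. A minimal per-rank sketch reusing the variables above:

# optional per-rank check: the local shard should equal the matching torch.chunk slice
expected = torch.chunk(big_tensor, world_size, dim=0)[idx]
assert torch.allclose(my_dtensor.to_local().cpu(), expected), f"rank {idx}: local shard mismatch"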

Advantages of DTensor

  • Provides a unified way to save and load a state_dict (see the checkpoint sketch after this list)
  • Enables tensor parallelism in eager mode, which is more flexible
  • Serves as the starting point for an SPMD programming model
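
For the first point, a state_dict that contains DTensors can be saved and loaded with torch.distributed.checkpoint, which writes each rank's shards without gathering them onto a single rank. A minimal sketch, assuming a recent PyTorch where dcp.save/dcp.load accept a checkpoint_id path; the script name and the "checkpoint_dir" directory are just placeholders:

# uv run torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_ckpt.py
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))
weight = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# each rank writes only its own shard of "weight"
state_dict = {"weight": weight}
dcp.save(state_dict, checkpoint_id="checkpoint_dir")

# load back in place; the shards are read according to the current placements
dcp.load(state_dict, checkpoint_id="checkpoint_dir")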

PyTorch DTensor

API

The following example demonstrates some basic DTensor operations:

  1. How to construct a DTensor directly to express different layouts: sharding, replication, and sharding + replication
  2. How to create a DTensor from a local torch.Tensor
  3. How to reshard an existing DTensor into a DTensor with a different placement

# uv run torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py
import torch
from torch.distributed.tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh

# construct a device mesh with available devices (multi-host or single host)
device_mesh = init_device_mesh("cuda", (4,))
# if we want to do row-wise sharding
rowwise_placement = [Shard(0)]
# if we want to do col-wise sharding
colwise_placement = [Shard(1)]

big_tensor = torch.randn(888, 12)
# distributed tensor returned will be sharded across the dimension specified in placements
rowwise_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=rowwise_placement)

# if we want to do replication across a certain device list
replica_placement = [Replicate()]
# distributed tensor will be replicated to all four GPUs.
replica_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=replica_placement)

# if we want to distribute a tensor with both replication and sharding
# use a separate 2-D mesh so the 1-D device_mesh above can still be used below
device_mesh_2d = init_device_mesh("cuda", (2, 2))
# replicate across the first dimension of the device mesh, then shard along the second dimension
spec = [Replicate(), Shard(0)]
partial_replica = distribute_tensor(big_tensor, device_mesh=device_mesh_2d, placements=spec)

# create a DistributedTensor that shards on dim 0, from a local torch.Tensor
local_tensor = torch.randn((8, 8), requires_grad=True)
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)

# reshard the current row-wise tensor to a colwise tensor or replicate tensor
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
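
To see what these layouts mean concretely, it helps to print each DTensor's placements and local shard shape on every rank, and to materialize the full tensor again; full_tensor() gathers the shards back into an ordinary replicated view. A short sketch continuing from the code above:

# inspect how each DTensor is laid out on the current rank
for name, t in [("rowwise", rowwise_tensor), ("colwise", colwise_tensor), ("replica", replica_tensor)]:
    print(f"{name}: placements={t.placements}, local shape={tuple(t.to_local().shape)}")

# reassemble the shards: every rank gets the same 32 x 8 tensor built from the four 8 x 8 local tensors
print(colwise_tensor.full_tensor().shape)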

Note that the APIs above operate directly on tensors; for modules we can use the higher-level, Module-oriented API:

import torch.nn as nn
from torch.distributed.tensor import Shard, distribute_tensor, distribute_module, init_device_mesh

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(self.fc1(input) + self.fc2(input))

mesh = init_device_mesh("cuda", (4,))

def shard_params(mod_name, mod, mesh):
    col_linear_placement = [Shard(0)]
    # shard fc1 and fc2
    if isinstance(mod, nn.Linear):
        for name, param in mod.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, mesh, col_linear_placement)
            )
            mod.register_parameter(name, dist_param)

sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_params)
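
After distribute_module runs, every parameter of sharded_module is a DTensor sharded on dim 0. The sketch below prints the resulting placements and then runs a forward pass with a replicated DTensor input; the 4 x 8 input shape is an arbitrary choice for illustration, and sharding should propagate through the linear layers automatically:

import torch
from torch.distributed.tensor import Replicate, distribute_tensor

# every parameter is now a DTensor: weights and biases are sharded on dim 0 across the 4 devices
for name, param in sharded_module.named_parameters():
    print(name, param.placements, tuple(param.to_local().shape))

# forward pass with a replicated input; the output is again a DTensor
inp = distribute_tensor(torch.randn(4, 8), mesh, [Replicate()])
out = sharded_module(inp)
print(out.placements, tuple(out.to_local().shape))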