PyTorch DTensor Explained
PyTorch DTensor
PyTorch introduces distributed tensor primitives that make it easier to write distributed computation in the SPMD (single program, multiple devices) style. These primitives can express the concepts of shard and replicate. Here is an example:
# run command: uv run torchrun --standalone --nnodes=1 --nproc-per-node=4 example.py
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor
torch.random.manual_seed(41)  # fix the seed so every rank builds the same tensor on CPU
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(dim=0)])
idx = mesh.get_rank()
print(f"cur rank: {idx}\n {my_dtensor}")
process_group = mesh.get_group()
world_size = mesh.size()
output = [torch.zeros(100000 // world_size, 88).cuda() for _ in range(world_size)]
dist.all_gather(output, my_dtensor.to_local(), group=process_group, async_op=True).wait()
output = [o.cpu() for o in output]
res = torch.cat(output)
if torch.allclose(res, big_tensor):
    print("Success!!!")
else:
    print(res, big_tensor)
    print("Bug found!!!")
In this example, every process initializes the same large tensor on CPU; we fix the random seed so that this big tensor is identical on all ranks, which is what lets us verify the correctness of the allgather at the end. We query the current process's rank from the mesh, and each rank ends up holding only its local shard of the tensor. Note that distribute_tensor follows single-device semantics: by default it is the rank-0 process's data that gets scattered when the big tensor is partitioned. We then use the all-gather collective to reassemble the full tensor, and finally verify that the sharding was correct.
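The manual allgather above is mainly for illustration. As a minimal sketch that reuses mesh, my_dtensor, and big_tensor from the example, the same check can be done with DTensor's own APIs, which issue the collective for you:
from torch.distributed.tensor import Replicate
# full_tensor() all-gathers the local shards and returns the global tensor on every rank
gathered = my_dtensor.full_tensor()
# equivalent: my_dtensor.redistribute(mesh, [Replicate()]).to_local()
assert torch.allclose(gathered.cpu(), big_tensor)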
Advantages of DTensor
- Provides a unified way to save and load a state_dict (see the checkpoint sketch after this list)
- Enables tensor parallelism in eager mode, which is more flexible
- Serves as the starting point for an SPMD programming model
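For the first point, the sketch below shows one hedged way to save and load a DTensor-based state_dict with torch.distributed.checkpoint; the checkpoint directory is made up, and sharded_module stands for any module whose parameters are DTensors (one is built with distribute_module at the end of this post):
import torch.distributed.checkpoint as dcp
state_dict = sharded_module.state_dict()          # values are DTensors; each rank holds only its shards
dcp.save(state_dict, checkpoint_id="./ckpt_dir")  # every rank writes just its own shards
# later, possibly under a different sharding layout, load back in place and restore
state_dict = sharded_module.state_dict()
dcp.load(state_dict, checkpoint_id="./ckpt_dir")
sharded_module.load_state_dict(state_dict)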
PyTorch DTensor API
The following example demonstrates some basic DTensor operations:
- How to construct a DTensor directly to express different placements: sharding, replication, and sharding + replication
- How to create a DTensor from a local torch.Tensor
- How to reshard an existing DTensor into a DTensor with a different placement
# uv run torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py
import torch
from torch.distributed.tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh
# construct a device mesh with available devices (multi-host or single host)
device_mesh = init_device_mesh("cuda", (4,))
# if we want to do row-wise sharding
rowwise_placement=[Shard(0)]
# if we want to do col-wise sharding
colwise_placement=[Shard(1)]
big_tensor = torch.randn(888, 12)
# the returned DTensor will be sharded across the dimension specified in placements
rowwise_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=rowwise_placement)
# if we want to do replication across a certain device list
replica_placement = [Replicate()]
# distributed tensor will be replicated to all four GPUs.
replica_tensor = distribute_tensor(big_tensor, device_mesh=device_mesh, placements=replica_placement)
# if we want to distribute a tensor with both replication and sharding,
# use a separate 2-D mesh so the 1-D examples below keep working
device_mesh_2d = init_device_mesh("cuda", (2, 2))
# replicate across the first dimension of the device mesh, then shard on the second dimension
spec = [Replicate(), Shard(0)]
partial_replica = distribute_tensor(big_tensor, device_mesh=device_mesh_2d, placements=spec)
# create a DTensor that shards on dim 0, from a local torch.Tensor
local_tensor = torch.randn((8, 8), requires_grad=True)
rowwise_tensor = DTensor.from_local(local_tensor, device_mesh, rowwise_placement)
# reshard the current row-wise tensor to a colwise tensor or replicate tensor
colwise_tensor = rowwise_tensor.redistribute(device_mesh, colwise_placement)
replica_tensor = colwise_tensor.redistribute(device_mesh, replica_placement)
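As a quick sanity check (a sketch that simply reuses the variables above on the 4-GPU 1-D mesh), every DTensor exposes its global shape, placements, and local shard:
print(rowwise_tensor.shape)             # global shape (32, 8): four (8, 8) local shards stacked on dim 0
print(colwise_tensor.placements)        # (Shard(dim=1),)
print(colwise_tensor.to_local().shape)  # (32, 2): the global tensor split along dim 1 across 4 ranks
print(replica_tensor.to_local().shape)  # (32, 8): a full copy on every rank after the final redistribute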
Note that the APIs above operate at the level of individual tensors; for nn.Module we can use the higher-level, module-oriented API:
import torch.nn as nn
from torch.distributed.tensor import Shard, distribute_tensor, distribute_module, init_device_mesh
class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(self.fc1(input) + self.fc2(input))

mesh = init_device_mesh("cuda", (4,))

def shard_params(mod_name, mod, mesh):
    col_linear_placement = [Shard(0)]
    # shard fc1 and fc2
    if isinstance(mod, nn.Linear):
        for name, param in mod.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(param, mesh, col_linear_placement)
            )
            mod.register_parameter(name, dist_param)

sharded_module = distribute_module(MyModule(), mesh, partition_fn=shard_params)
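A short usage sketch, not part of the original example: feed a replicated input DTensor through the sharded module (the input shape here is made up). With the parameters sharded on dim 0, the output is expected to come back sharded on its feature dimension:
import torch
from torch.distributed.tensor import Replicate
inp = distribute_tensor(torch.randn(16, 8), mesh, [Replicate()])
out = sharded_module(inp)
print(out.placements)          # expected: (Shard(dim=1),) on the 4-GPU mesh
print(out.to_local().shape)    # expected: (16, 2) per rank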