Ray + NCCL + Tensor Parallel Training

Data Parallelism

The basic recipe:

  • Start from a distributed execution environment, e.g. Ray.
  • Pick a collective-communication backend, e.g. NCCL.
  • Wrap the model in DistributedDataParallel (DDP) and assign each process a global rank.
  • Sample the data with a DistributedSampler so every rank sees its own shard.
  • Run the training loop.

Suppose the Ray cluster has two nodes, the head node's IP is 10.230.40.150, and each node contributes one GPU when the cluster is started.

import os
os.environ["RAY_DEDUP_LOGS"] = "0"  # show logs from every worker instead of deduplicating them
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import ray
import time

def init_distributed_mode(rank, world_size):
    print("init start")
    device = torch.device('cuda')  # each Ray task is pinned to a single GPU, so 'cuda' resolves to it
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://10.230.40.150:23456',  # TCP rendezvous on the head node
        world_size=world_size, 
        rank=rank
    )
    print("init end")
    return device

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

def setup_model(rank, device):
    model = SimpleModel().to(device)
    return DDP(model, device_ids=[device])  # wrap in DDP so gradients are all-reduced across ranks

def setup_dataloader(rank, world_size, batch_size=32):
    dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)  # each rank gets its own shard
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return dataloader

@ray.remote(num_gpus=1)  # ensure each task is allocated one GPU
def train(rank, world_size):
    device = init_distributed_mode(rank, world_size)
    model = setup_model(rank, device)
    dataloader = setup_dataloader(rank, world_size)

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for epoch in range(5):
        dataloader.sampler.set_epoch(epoch)  # reshuffle the sampler differently each epoch
        time.sleep(60)  # deliberate 60 s pause at the start of each epoch
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
        print(f"Rank {rank}, Epoch {epoch + 1}, Loss: {loss.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    ray.init(address="auto")  # connect to the already running Ray cluster
    remote_funcs = [train.remote(rank, world_size) for rank in range(world_size)]
    ray.get(remote_funcs)
    ray.shutdown()

If world_size is set to 4:

  • Only 2 of the tasks can obtain a GPU and start; the other 2 wait indefinitely, because num_gpus=1 is a hard resource requirement.
  • The 2 tasks that did start block inside dist.init_process_group(), because NCCL waits for the remaining ranks (world_size=4).
(train pid=63009) init start
(train pid=37071, ip=10.230.3.42) init start
(train pid=37071, ip=10.230.3.42) init end

They stay stuck until the rendezvous eventually fails with: torch.distributed.DistBackendError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Socket Timeout
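
One way to avoid the silent hang is to check, before submitting the tasks, that the cluster actually exposes enough GPUs for the requested world_size. A minimal sketch, reusing ray and the train task from the script above (launch_if_enough_gpus is an illustrative helper, not part of the original script):

def launch_if_enough_gpus(world_size):
    # ray.cluster_resources() reports cluster-wide totals, including the "GPU" count
    available_gpus = int(ray.cluster_resources().get("GPU", 0))
    if available_gpus < world_size:
        raise RuntimeError(
            f"world_size={world_size} but the cluster only exposes {available_gpus} GPU(s); "
            "the extra tasks would queue forever and the started ranks would time out in init_process_group()"
        )
    return [train.remote(rank, world_size) for rank in range(world_size)]

dist.init_process_group also accepts a timeout argument (a datetime.timedelta), which can make a misconfigured launch fail faster than the default timeout.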

Tensor Parallel Inference
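
Before the distributed version, the idea behind the two parallel linear layers in the script can be checked on a single process: splitting the weight along the output dimension (column parallelism) and concatenating the partial outputs, or splitting it along the input dimension (row parallelism) and summing the partial outputs, reproduces the full linear layer. A minimal single-process sketch (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)   # [batch, input_size]
w = torch.randn(6, 8)   # full weight: [output_size, input_size]
full = F.linear(x, w)   # reference result

# Column parallelism: each "rank" owns a slice of the output features.
col_parts = [F.linear(x, w_shard) for w_shard in w.chunk(2, dim=0)]
assert torch.allclose(torch.cat(col_parts, dim=-1), full)

# Row parallelism: each "rank" owns a slice of the input features; partial results are summed.
row_parts = [F.linear(x_shard, w_shard)
             for x_shard, w_shard in zip(x.chunk(2, dim=-1), w.chunk(2, dim=1))]
assert torch.allclose(sum(row_parts), full, atol=1e-5)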

Suppose the head node is reachable at tcp://10.230.20.229:23456:

import os
os.environ["RAY_DEDUP_LOGS"] = "0"
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import time


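# Column-parallel linear layer: the weight is split along the output dimension, so each
# rank computes its own slice of the output features. With is_gather=True the slices
# are all-gathered and concatenated back into the full output.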
class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, rank, world_size, is_gather = False):
        super(ColumnParallelLinear, self).__init__()
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.rank = rank
        self.input_size = input_size
        self.output_size = output_size
        self.local_output_size = output_size // self.world_size
        self.weight = nn.Parameter(torch.empty(self.local_output_size, input_size))
        self.bias = nn.Parameter(torch.empty(self.local_output_size))
        self.is_gather = is_gather

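    # Deterministic demo initialization: the full weight/bias are filled with arange
    # values, each rank copies only its own slice, and a per-layer offset is added so
    # the q/k/v/out projections end up with different parameters.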
    def init_weight(self, offset):
        with torch.no_grad():
            weight_value = torch.arange(0, self.weight.numel() * self.world_size).reshape(self.output_size, self.input_size)[self.rank * self.local_output_size : (self.rank + 1) * self.local_output_size, :]
            self.weight.copy_(weight_value)
            bias_value = torch.arange(0, self.bias.numel() * self.world_size)[self.rank * self.local_output_size : (self.rank + 1) * self.local_output_size]
            self.bias.copy_(bias_value)
            self.weight += offset
            self.bias += offset

    def forward(self, x):
        output_ = F.linear(x, self.weight) + self.bias
        if self.is_gather:
            # gather the partial outputs from every rank and concatenate along the feature dim
            output_list = [torch.zeros_like(output_) for _ in range(self.world_size)]
            dist.all_gather(output_list, output_)
            output = torch.cat(output_list, dim=-1)
        else:
            output = output_
        return output



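# Row-parallel linear layer: the weight is split along the input dimension, so each rank
# consumes a shard of the input features. With is_reduce=True the partial outputs are
# summed across ranks with all_reduce, and the bias is added once after the reduction.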
class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, rank, world_size, is_reduce = True):
        super(RowParallelLinear, self).__init__()
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.rank = rank
        self.input_size = input_size
        self.output_size = output_size
        self.local_input_size = input_size // self.world_size
        self.weight = nn.Parameter(torch.empty(output_size, self.local_input_size))
        self.bias = nn.Parameter(torch.empty(output_size))
        self.is_reduce = is_reduce
    
    def init_weight(self, offset):
        with torch.no_grad():
            weight_values = torch.arange(0, self.weight.numel() * self.world_size).reshape(self.output_size, self.input_size)[:, self.rank * self.local_input_size : (self.rank+1) * self.local_input_size]
            self.weight.copy_(weight_values)
            bias_value = torch.arange(0, self.bias.numel()).float()
            self.bias.copy_(bias_value)
            self.weight += offset
            self.bias += offset
        
    def forward(self, x):
        output_ = F.linear(x, self.weight)
        if self.is_reduce:
            dist.all_reduce(output_, op=dist.ReduceOp.SUM)  # sum the partial results from every rank
        output = output_ + self.bias  # bias is added once, after the reduction
        return output



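# Tensor-parallel multi-head attention: the Q/K/V projections are column-parallel (each
# rank keeps num_heads // world_size heads and skips the gather), attention runs on the
# local heads only, and the row-parallel output projection assembles the full result on
# every rank with a single all_reduce.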
class MultiHeadAttentionTP(nn.Module):
    def __init__(self, embed_dim, num_heads, bias, rank, world_size): 
        super().__init__()
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.node_num_heads = self.total_num_heads // world_size
        self.node_embed_dim = self.embed_dim // world_size
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scaling = self.head_dim ** -0.5  # 1 / sqrt(head_dim)
        self.q_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.k_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.v_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, rank, world_size, True)
        self.init_weight()

    def init_weight(self):
        self.q_proj.init_weight(0.1)
        self.k_proj.init_weight(-0.2)
        self.v_proj.init_weight(0.3)
        self.out_proj.init_weight(-0.4)



    # q, k, v: [num_seq, node_num_heads, seq_len, head_dim]
    def attention(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling  # [num_seq, node_num_heads, seq_len, seq_len]
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)  # [num_seq, node_num_heads, seq_len, seq_len]
        attn_output = torch.matmul(attn_weights, v)  # [num_seq, node_num_heads, seq_len, head_dim]
        return attn_output

    
    def forward(self, hidden_state):
        num_seq, seq_len, embed_dim = hidden_state.size() # hidden_state : [num_seq, seq_len, embed_dim]
        q = self.q_proj(hidden_state) # [num_seq, seq_len, node_embed_dim], embed_dim // world_size = node_num_heads * head_dim = node_embed_dim
        k = self.k_proj(hidden_state) # [num_seq, seq_len, node_embed_dim], embed_dim // world_size = node_num_heads * head_dim = node_embed_dim
        v = self.v_proj(hidden_state) # [num_seq, seq_len, node_embed_dim], embed_dim // world_size = node_num_heads * head_dim = node_embed_dim

        q = q.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1,2) # [num_seq, node_num_heads, seq_len,  head_dim]
        k = k.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1,2) # [num_seq, node_num_heads, seq_len,  head_dim]
        v = v.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1,2)  # [num_seq, node_num_heads, seq_len,  head_dim]

        attn_output = self.attention(q, k, v) # [num_seq,  node_num_heads, seq_len, head_dim]

        attn_output = attn_output.transpose(1,2).contiguous().view(num_seq, seq_len, self.node_embed_dim)
        output = self.out_proj(attn_output)
        return output
    
world_size = 2
def init_distributed_mode(rank):
    print("init start")
    device = torch.device("cuda")
    dist.init_process_group(
        backend='nccl', 
        init_method='tcp://10.230.20.229:23456', 
        world_size=world_size, 
        rank=rank
    )
    print("init end")
    return device

class SimpleModel(nn.Module):
    def __init__(self, rank, world_size):
        super(SimpleModel, self).__init__()
        self.attn = MultiHeadAttentionTP(512, 8, True, rank, world_size)  # embed_dim=512, 8 heads split across ranks
    
    def forward(self, x):
        out = self.attn(x)
        return out
    
@ray.remote(num_gpus=1)
def example(rank):
    time.sleep(5 * 60)  # deliberate 5-minute pause before the NCCL setup
    device = init_distributed_mode(rank)
    with torch.no_grad():
        model = SimpleModel(rank, world_size).to(device).eval()
        if rank == 0:
            torch.manual_seed(42)
            x = torch.randn(128, 8, 512).to(device)
        else:
            x = torch.empty(128, 8, 512).to(device)
        dist.broadcast(x, src=0)  # every rank runs attention on the same input
        out = model(x)
        print(out.mean())
    dist.destroy_process_group()

if __name__ == "__main__":
    ray.init(address = "auto")
    remotes = [example.remote(rank) for rank in range(world_size)]
    ray.get(remotes)
    # clean up
    ray.shutdown()
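
Because the input is broadcast from rank 0 and the row-parallel output projection all-reduces the partial results, every rank ends up holding the complete attention output, so both workers should print the same out.mean() value.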

posted @ 2025-05-23 20:33  xiezhengcai