Ray + NCCL + Tensor-Parallel Training
Data Parallelism
First, you need a distributed execution environment, for example Ray,
and a collective-communication backend, for example NCCL.
Wrap the model in DDP (DistributedDataParallel) and give each process its global rank.
Sample the data with a DistributedSampler.
Then run the training loop.
Suppose the Ray cluster has two nodes, the head node is at 10.230.40.150, and each node contributes one GPU when the cluster is started; a quick pre-flight check of the cluster is sketched below.
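Before launching the job it is worth confirming that Ray actually sees both nodes and both GPUs. A minimal sketch, assuming the cluster was brought up with ray start --head on 10.230.40.150 and ray start --address=10.230.40.150:6379 on the worker (the port is illustrative):

import ray

ray.init(address="auto")
resources = ray.cluster_resources()                   # aggregate resources of the whole cluster
alive_nodes = [n for n in ray.nodes() if n["Alive"]]  # node metadata reported by Ray
print(f"nodes: {len(alive_nodes)}, GPUs: {resources.get('GPU', 0)}")
assert resources.get("GPU", 0) >= 2, "need at least 2 GPUs for world_size=2"
ray.shutdown()

The full training script: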
import os
os.environ["RAY_DEDUP_LOGS"] = "0"
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import ray
import time
def init_distributed_mode(rank, world_size):
    print("init start")
    device = torch.device('cuda')
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://10.230.40.150:23456',
        world_size=world_size,
        rank=rank
    )
    print("init end")
    return device

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

def setup_model(rank, device):
    model = SimpleModel().to(device)
    return DDP(model, device_ids=[device])

def setup_dataloader(rank, world_size, batch_size=32):
    dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return dataloader
@ray.remote(num_gpus=1)  # make sure each task is allocated one GPU
def train(rank, world_size):
    device = init_distributed_mode(rank, world_size)
    model = setup_model(rank, device)
    dataloader = setup_dataloader(rank, world_size)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    for epoch in range(5):
        time.sleep(60)
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
        print(f"Rank {rank}, Epoch {epoch + 1}, Loss: {loss.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    ray.init(address="auto")
    remote_funcs = [train.remote(rank, world_size) for rank in range(world_size)]
    ray.get(remote_funcs)
    ray.shutdown()
If world_size is set to 4:
- Only 2 tasks obtain a GPU and start; the other 2 wait forever, because num_gpus=1 is a hard requirement.
- The 2 tasks that did start block in dist.init_process_group(), since NCCL waits for the remaining ranks (world_size=4).
The observed output is:
(train pid=63009) init start
(train pid=37071, ip=10.230.3.42) init start
(train pid=37071, ip=10.230.3.42) init end
until eventually the error appears:
torch.distributed.DistBackendError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Socket Timeout
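Two small guards avoid that hang (a sketch, not part of the original post): derive world_size from the number of GPUs Ray actually reports, and give the process-group setup a finite timeout so a missing rank turns into an error within minutes instead of blocking until the default timeout.

import datetime
import ray
import torch.distributed as dist

ray.init(address="auto")
# Launch exactly as many ranks as there are GPUs registered with Ray,
# so no task sits in the queue waiting for a GPU that does not exist.
world_size = int(ray.cluster_resources().get("GPU", 0))

def init_distributed_mode(rank, world_size):
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://10.230.40.150:23456",
        world_size=world_size,
        rank=rank,
        # Fail fast if some rank never reaches the rendezvous.
        timeout=datetime.timedelta(minutes=2),
    )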
Tensor-Parallel Inference
If the head node is at tcp://10.230.20.229:23456:
import os
os.environ["RAY_DEDUP_LOGS"] = "0"
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import time
class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, rank, world_size, is_gather=False):
        super(ColumnParallelLinear, self).__init__()
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.rank = rank
        self.input_size = input_size
        self.output_size = output_size
        self.local_output_size = output_size // self.world_size
        self.weight = nn.Parameter(torch.empty(self.local_output_size, input_size))
        self.bias = nn.Parameter(torch.empty(self.local_output_size))
        self.is_gather = is_gather

    def init_weight(self, offset):
        # Deterministic toy initialization: take this rank's row slice of a full
        # arange-valued weight and bias, then shift everything by `offset`.
        with torch.no_grad():
            weight_value = torch.arange(0, self.weight.numel() * self.world_size).reshape(self.output_size, self.input_size)[self.rank * self.local_output_size : (self.rank + 1) * self.local_output_size, :]
            self.weight.copy_(weight_value)
            bias_value = torch.arange(0, self.bias.numel() * self.world_size)[self.rank * self.local_output_size : (self.rank + 1) * self.local_output_size]
            self.bias.copy_(bias_value)
            self.weight += offset
            self.bias += offset
    def forward(self, x):
        output_ = F.linear(x, self.weight) + self.bias
        if self.is_gather:
            # Collect every rank's output slice and concatenate along the feature dim.
            output = [torch.zeros_like(output_) for _ in range(self.world_size)]
            dist.all_gather(output, output_)
            output = torch.cat(output, dim=-1)
        else:
            output = output_
        return output
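Why gathering along the last dimension recovers the full layer: a column-parallel linear gives each rank a slice of the output features, so concatenating the per-rank results equals the unsharded F.linear. A minimal single-process check (plain tensors with arbitrary sizes, no NCCL involved):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)        # [batch, input_size]
w = torch.randn(6, 8)        # full weight: [output_size, input_size]
b = torch.randn(6)

# Shard the output dimension across two "ranks".
w0, w1 = w[:3], w[3:]
b0, b1 = b[:3], b[3:]

full = F.linear(x, w, b)
sharded = torch.cat([F.linear(x, w0, b0), F.linear(x, w1, b1)], dim=-1)
print(torch.allclose(full, sharded))   # True: concatenated shard outputs == full output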
class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, rank, world_size, is_reduce=True):
        super(RowParallelLinear, self).__init__()
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.rank = rank
        self.input_size = input_size
        self.output_size = output_size
        self.local_input_size = input_size // self.world_size
        self.weight = nn.Parameter(torch.empty(output_size, self.local_input_size))
        self.bias = nn.Parameter(torch.empty(output_size))
        self.is_reduce = is_reduce

    def init_weight(self, offset):
        # Deterministic toy initialization: take this rank's column slice of a full
        # arange-valued weight, then shift everything by `offset`.
        with torch.no_grad():
            weight_values = torch.arange(0, self.weight.numel() * self.world_size).reshape(self.output_size, self.input_size)[:, self.rank * self.local_input_size : (self.rank + 1) * self.local_input_size]
            self.weight.copy_(weight_values)
            bias_value = torch.arange(0, self.bias.numel()).float()
            self.bias.copy_(bias_value)
            self.weight += offset
            self.bias += offset

    def forward(self, x):
        output_ = F.linear(x, self.weight)
        if self.is_reduce:
            # Sum the partial products from all ranks, then add the bias once.
            dist.all_reduce(output_, op=dist.ReduceOp.SUM)
        output = output_ + self.bias
        return output
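The row-parallel counterpart splits the input features instead: each rank multiplies its slice of x by the matching slice of the weight, the partial products are summed (which is exactly what the all_reduce does), and the bias is added once after the reduction. The same kind of single-process check:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)        # [batch, input_size]
w = torch.randn(6, 8)        # full weight: [output_size, input_size]
b = torch.randn(6)

# Shard the input dimension across two "ranks".
x0, x1 = x[:, :4], x[:, 4:]
w0, w1 = w[:, :4], w[:, 4:]

full = F.linear(x, w, b)
reduced = F.linear(x0, w0) + F.linear(x1, w1) + b   # sum of partials stands in for all_reduce; bias added once
print(torch.allclose(full, reduced))                # True up to floating-point rounding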
class MultiHeadAttentionTP(nn.Module):
    def __init__(self, embed_dim, num_heads, bias, rank, world_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.node_num_heads = self.total_num_heads // world_size
        self.node_embed_dim = self.embed_dim // world_size
        self.head_dim = self.embed_dim // self.total_num_heads
        self.scaling = self.head_dim ** -0.5
        self.q_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.k_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.v_proj = ColumnParallelLinear(self.embed_dim, self.embed_dim, rank, world_size, False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, rank, world_size, True)
        self.init_weight()

    def init_weight(self):
        self.q_proj.init_weight(0.1)
        self.k_proj.init_weight(-0.2)
        self.v_proj.init_weight(0.3)
        self.out_proj.init_weight(-0.4)

    # q, k, v: [num_seq, node_num_heads, seq_len, head_dim]
    def attention(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling  # [num_seq, node_num_heads, seq_len, seq_len]
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)  # [num_seq, node_num_heads, seq_len, seq_len]
        attn_output = torch.matmul(attn_weights, v)  # [num_seq, node_num_heads, seq_len, head_dim]
        return attn_output

    def forward(self, hidden_state):
        num_seq, seq_len, embed_dim = hidden_state.size()  # hidden_state: [num_seq, seq_len, embed_dim]
        q = self.q_proj(hidden_state)  # [num_seq, seq_len, node_embed_dim], embed_dim // world_size = node_num_heads * head_dim = node_embed_dim
        k = self.k_proj(hidden_state)  # [num_seq, seq_len, node_embed_dim]
        v = self.v_proj(hidden_state)  # [num_seq, seq_len, node_embed_dim]
        q = q.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1, 2)  # [num_seq, node_num_heads, seq_len, head_dim]
        k = k.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1, 2)  # [num_seq, node_num_heads, seq_len, head_dim]
        v = v.view(num_seq, seq_len, self.node_num_heads, self.head_dim).transpose(1, 2)  # [num_seq, node_num_heads, seq_len, head_dim]
        attn_output = self.attention(q, k, v)  # [num_seq, node_num_heads, seq_len, head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(num_seq, seq_len, self.node_embed_dim)
        output = self.out_proj(attn_output)
        return output
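With the configuration used further down (embed_dim=512, num_heads=8, world_size=2), each rank therefore holds 4 heads of dimension 64, i.e. a 256-wide slice of every projection. The attention itself needs no communication at all; the only collective in the whole block is the single all_reduce inside out_proj. The arithmetic, spelled out:

embed_dim, num_heads, world_size = 512, 8, 2
head_dim = embed_dim // num_heads            # 64
node_num_heads = num_heads // world_size     # 4 heads per rank
node_embed_dim = embed_dim // world_size     # 256, the width of each rank's q/k/v slice
assert node_embed_dim == node_num_heads * head_dim

The driver below wires this module into the same Ray + NCCL setup as the data-parallel example: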
world_size = 2
def init_distributed_mode(rank):
    print("init start")
    device = torch.device("cuda")
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://10.230.20.229:23456',
        world_size=world_size,
        rank=rank
    )
    print("init end")
    return device

class SimpleModel(nn.Module):
    def __init__(self, rank, world_size):
        super(SimpleModel, self).__init__()
        self.attn = MultiHeadAttentionTP(512, 8, True, rank, world_size)

    def forward(self, x):
        out = self.attn(x)
        return out
@ray.remote(num_gpus=1)
def example(rank):
    time.sleep(5 * 60)
    device = init_distributed_mode(rank)
    with torch.no_grad():
        model = SimpleModel(rank, world_size).to(device).eval()
        if rank == 0:
            torch.manual_seed(42)
            x = torch.randn(128, 8, 512).to(device)
        else:
            x = torch.empty(128, 8, 512).to(device)
        # Every rank must compute on the same input, so rank 0 broadcasts its batch.
        dist.broadcast(x, src=0)
        out = model(x)
        print(out.mean())
    dist.destroy_process_group()
if __name__ == "__main__":
    ray.init(address="auto")
    remotes = [example.remote(rank) for rank in range(world_size)]
    ray.get(remotes)
    # cleanup
    ray.shutdown()
Knowledge is what we already know,
and also what we do not yet know.
Standing on the knowledge we have,
we go on to discover the unknown,
and so knowledge is extended.
The more knowledge we gain,
the more unknowns there are,
so the expansion of knowledge never ends.
