Distributed Training of Neural Networks
DistributedDataParallel
DistributedDataParallel and Elastic are two related PyTorch components for distributed training, and they are often used together.
- DistributedDataParallel implements data-parallel distributed training across multiple GPUs. It runs one replica of the model on each GPU, and each replica performs its forward and backward passes independently. Gradients are synchronized after every backward pass, so all model replicas stay consistent.
- Elastic provides elastic scheduling and fault tolerance for distributed training. It allows nodes to be added or removed dynamically during a run and can automatically recover after a node failure.
Download the dataset
python -c "from torchvision.datasets import MNIST; MNIST(root='data', train=True, download=True)"
Model definition
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 5, kernel_size=3, stride=1, padding=1)
        self.norm = nn.BatchNorm2d(5)
        # The 28x28 kernel collapses the spatial dimensions, yielding one logit per class
        self.head = nn.Conv2d(5, 10, kernel_size=28, stride=1, padding=0)

    def forward(self, x):
        x = self.norm(self.conv(x))
        x = self.head(x)
        return x.squeeze()
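As a quick sanity check (not part of the original tutorial), the model can be run on a dummy MNIST-shaped batch to confirm that it produces one logit per class:

import torch as th

model = MyModel()
dummy = th.randn(4, 1, 28, 28)   # a fake batch of 4 single-channel 28x28 images
out = model(dummy)
print(out.shape)                 # expected: torch.Size([4, 10])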
Initialize the distributed process group
import torch as th
import torch.distributed as dist
from os import environ

world_size = int(environ['WORLD_SIZE'])   # total number of processes
rank = int(environ['RANK'])               # global rank of this process
local_rank = int(environ['LOCAL_RANK'])   # rank of this process on the local node
th.cuda.set_device(local_rank)            # set the default GPU for this process

# Initialize the distributed process group
dist.init_process_group(
    backend='nccl',
    init_method='env://',
    world_size=world_size,
    rank=rank,
    device_id=th.device(f'cuda:{local_rank}')   # bind the process group to this GPU
)
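The WORLD_SIZE, RANK, and LOCAL_RANK variables are normally populated by the launcher (torchrun, see below). As a convenience sketch only, for a quick single-process smoke test without a launcher, they could be set manually before the snippet above runs:

import os

# Debugging defaults only; torchrun supplies the real values in a distributed run.
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('LOCAL_RANK', '0')
os.environ.setdefault('MASTER_ADDR', 'localhost')   # required by init_method='env://'
os.environ.setdefault('MASTER_PORT', '29500')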
Data sampling
Use DistributedSampler to split the dataset into subsets, one per GPU (process).
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
target_transform = transforms.Lambda(lambda x: F.one_hot(th.tensor(x), num_classes=10).float())
dataset_train = MNIST(root='data', train=True, transform=transforms.ToTensor(), target_transform=target_transform)
dataset_val = MNIST(root='data', train=False, transform=transforms.ToTensor())
sampler_train = DistributedSampler(dataset_train, num_replicas=world_size, rank=rank, shuffle=True)
sampler_val = DistributedSampler(dataset_val, num_replicas=world_size, rank=rank, shuffle=False)
data_loader_train = DataLoader(dataset_train, batch_size=256, sampler=sampler_train, num_workers=4, pin_memory=True)
data_loader_val = DataLoader(dataset_val, batch_size=256, sampler=sampler_val, num_workers=4, pin_memory=True)
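A small optional check (assuming the variables defined above) of how the data is sharded: len(sampler_train) is roughly len(dataset_train) / world_size, padded so every rank gets the same number of samples.

# Each rank iterates only over its own shard of the data.
print(f'rank {rank}: {len(sampler_train)} train samples, '
      f'{len(data_loader_train)} train batches per epoch')
# With 60,000 MNIST training images and e.g. 4 processes,
# len(sampler_train) would be 15,000.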
Create the model
Wrap the model with DistributedDataParallel so that gradients are synchronized across processes.
from torch.nn.parallel import DistributedDataParallel as DDP
model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])
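After wrapping, the original module is still accessible as model.module; the training loop below relies on this to save checkpoints without the DDP parameter-name prefix. A quick illustration:

# DDP adds a 'module.' prefix to parameter names in its own state_dict;
# model.module is the original, unwrapped MyModel instance.
assert isinstance(model.module, MyModel)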
Training
from torch.optim import SGD
from timm.utils import accuracy
from pathlib import Path

criterion = nn.MSELoss()
optimizer = SGD(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    sampler_train.set_epoch(epoch)   # reshuffle so each epoch uses a different split
    for samples, labels in data_loader_train:
        samples, labels = samples.cuda(), labels.cuda()
        outputs = model(samples)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate and save a checkpoint on rank 0
    if (epoch % 2 == 0 or epoch + 1 == 10) and rank == 0:
        model.eval()
        total_acc1 = 0.0
        total_acc3 = 0.0
        num_batches = 0
        with th.inference_mode():
            for samples, labels in data_loader_val:
                samples, labels = samples.cuda(), labels.cuda()
                outputs = model(samples)
                acc1, acc3 = accuracy(outputs, labels, topk=(1, 3))
                total_acc1 += acc1.item()
                total_acc3 += acc3.item()
                num_batches += 1
        avg_acc1 = total_acc1 / num_batches
        avg_acc3 = total_acc3 / num_batches
        print(f'Epoch [{epoch}/10] - acc1: {avg_acc1:.4f}, acc3: {avg_acc3:.4f}')
        Path('checkpoints').mkdir(exist_ok=True)
        # Save the unwrapped module so the checkpoint loads without the 'module.' prefix
        th.save(model.module.state_dict(), f'checkpoints/checkpoint-{epoch}.pt')

# Exit and clean up the distributed process group
dist.destroy_process_group()
Load a checkpoint
state_dict = th.load("checkpoints/checkpoint-9.pt", map_location="cpu", weights_only=True)
model = MyModel()
model.load_state_dict(state_dict)
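If a checkpoint was instead saved from the DDP-wrapped model (so its keys carry a module. prefix), the prefix can be stripped before loading into a plain MyModel; a minimal sketch:

# Strip the 'module.' prefix added by the DDP wrapper, if present.
state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)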
Elastic distributed scheduling
Distributed training can be launched with the torchrun command:
export OMP_NUM_THREADS=1
torchrun \
--nnodes 1 \
--nproc-per-node auto \
--rdzv-id 0 \
--rdzv-backend c10d \
--rdzv-endpoint localhost:0 \
main.py
- --nnodes: number of nodes
- --nproc-per-node: number of processes on the current node
- --rdzv-id: job id
- --rdzv-backend: rendezvous backend
- --rdzv-endpoint: rendezvous endpoint
For multi-node training, run the torchrun command on every node, as sketched below, or use a cluster job manager such as Slurm.
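For example, a two-node run might look like the following on each node; node0.example.com is a placeholder hostname for the first node, and all flags are standard torchrun options:

export OMP_NUM_THREADS=1
torchrun \
--nnodes 2 \
--nproc-per-node auto \
--rdzv-id 42 \
--rdzv-backend c10d \
--rdzv-endpoint node0.example.com:29400 \
main.py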
References:
- Getting Started with Distributed Data Parallel | PyTorch Tutorials
- Multinode Training | PyTorch Tutorials
- PyTorch Elastic Quickstart | PyTorch documentation
- torchrun (Elastic Launch) | PyTorch documentation
Accelerate
Accelerate is Hugging Face's wrapper around PyTorch DDP that makes DDP easier to use.
pip install accelerate  # install accelerate
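Optionally (not required for the changes below), accelerate ships a small CLI that can generate a default launch configuration interactively, so later accelerate launch commands need fewer flags:

accelerate config   # interactive wizard; writes a default config file
accelerate env      # print the current environment and config for debugging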
Training
Adapting existing code for distributed training with accelerate is straightforward; only the following changes (shown as a diff against the original training loop) are needed:
+ from accelerate import Accelerator
+ accelerator = Accelerator()
+ model, optimizer, dataloader, scheduler = accelerator.prepare(
+     model, optimizer, dataloader, scheduler
+ )
  for samples, labels in dataloader:
-     samples = samples.to(device)
-     labels = labels.to(device)
      outputs = model(samples)
      loss = criterion(outputs, labels)
      optimizer.zero_grad()
-     loss.backward()
+     accelerator.backward(loss)
      optimizer.step()
      scheduler.step()
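For completeness, a hedged sketch of how a checkpoint is typically saved in this setup, using Accelerate's documented helpers (wait_for_everyone, is_main_process, unwrap_model, save); the file name is a placeholder:

accelerator.wait_for_everyone()                    # make sure all ranks finished the epoch
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(model)    # strip the DDP wrapper
    accelerator.save(unwrapped.state_dict(), 'checkpoint.pt')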
Launch distributed training:
accelerate launch \
--multi_gpu \
--main_process_ip '127.0.0.1' \
--main_process_port 29500 \
--num_machines 1 \
--num_processes 8 \
--machine_rank 0 \
--dynamo_backend no \
--mixed_precision bf16 \
train.py
- --multi_gpu: enable multi-GPU distributed training
- --main_process_ip: main process address
- --main_process_port: main process port
- --num_machines: number of nodes
- --num_processes: total number of processes
- --machine_rank: rank of the current node
- --dynamo_backend: TorchDynamo backend, defaults to no
- --mixed_precision: mixed-precision mode, one of {no,fp16,bf16,fp8}
Since Accelerate is a wrapper around PyTorch DDP, the script can also be launched with torchrun.
Reference: Accelerate | Hugging Face Docs
Trainer
Trainer is the highest level of abstraction: there is no training loop to write at all; we only set the training arguments and Trainer takes care of everything:
import torch as th
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments

model = MyModel()

# Set the training hyperparameters
training_args = TrainingArguments(
    "basic-trainer",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    evaluation_strategy="epoch",
    remove_unused_columns=False
)

def collate_fn(examples):
    """Collate (image, label) pairs into a training batch."""
    x = th.stack([example[0] for example in examples])
    labels = th.tensor([example[1] for example in examples])
    return {"x": x, "labels": labels}

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the loss (**kwargs absorbs extra arguments passed by newer Trainer versions)."""
        outputs = model(inputs["x"])
        targets = inputs["labels"]
        loss = F.nll_loss(outputs, targets)
        return (loss, outputs) if return_outputs else loss

trainer = MyTrainer(
    model,
    training_args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    data_collator=collate_fn
)
trainer.train()
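Because Trainer picks up the distributed environment created by torchrun on its own, the same script can be launched just like the DDP example; train_trainer.py is a placeholder name:

torchrun --nproc-per-node 8 train_trainer.py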
Reference: From PyTorch DDP to Accelerate to Trainer, mastering distributed training with ease | Hugging Face
