Distributed Training of Neural Networks

DistributedDataParallel

In PyTorch, DistributedDataParallel and Elastic are two related components that support distributed training, and they are often used together.

  • DistributedDataParallel implements data-parallel distributed training across multiple GPUs. It runs one replica of the model on each GPU and performs the forward and backward passes independently. Gradients are synchronized after every backward pass, so all replicas stay consistent (see the conceptual sketch after this list).
  • Elastic provides elastic scheduling and fault recovery for distributed training. It allows nodes to be added or removed dynamically during a job and recovers automatically from node failures.
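
What the gradient synchronization amounts to can be illustrated with a conceptual sketch (this is not how DDP is implemented internally; DDP overlaps communication with the backward pass using gradient buckets). The hypothetical helper average_gradients below shows the effect: after backward(), each parameter's gradient is summed over all ranks and averaged, so every replica applies the same update.

import torch.distributed as dist

def average_gradients(model):
    """Hypothetical helper: average gradients across all ranks (what DDP effectively does)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # sum this gradient over all ranks
            param.grad /= world_size                           # average it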

Download the Dataset

python -c "from torchvision.datasets import MNIST; MNIST(root='data', train=True, download=True)"

Model Definition

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 5, kernel_size=3, stride=1, padding=1)    # (N, 1, 28, 28) -> (N, 5, 28, 28)
        self.norm = nn.BatchNorm2d(5)
        self.head = nn.Conv2d(5, 10, kernel_size=28, stride=1, padding=0)  # (N, 5, 28, 28) -> (N, 10, 1, 1)

    def forward(self, x):
        x = self.norm(self.conv(x))
        x = self.head(x)
        return x.squeeze()  # (N, 10, 1, 1) -> (N, 10) class logits
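
As a quick sanity check (a hypothetical snippet, not part of the training script), a batch of 28×28 MNIST images maps to one logit per class:

import torch

x = torch.randn(4, 1, 28, 28)  # (batch, channels, height, width)
print(MyModel()(x).shape)      # torch.Size([4, 10])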

Initialize the Distributed Process Group

import torch as th
import torch.distributed as dist
from os import environ

world_size = int(environ['WORLD_SIZE'])  # total number of processes
rank = int(environ['RANK'])              # global rank of this process
local_rank = int(environ['LOCAL_RANK'])  # local rank of this process on its node
th.cuda.set_device(local_rank)           # set the default GPU for this process

# Initialize the distributed process group
dist.init_process_group(
    backend='nccl',
    init_method='env://',
    world_size=world_size,
    rank=rank,
    device_id=th.device('cuda', local_rank)  # bind this process to its GPU
)
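
As an optional sanity check (hypothetical, not required), every rank can report in, and a barrier verifies that all processes can communicate over the NCCL backend:

print(f'rank {dist.get_rank()}/{dist.get_world_size()} ready on cuda:{local_rank}')
dist.barrier()  # blocks until every rank has reached this point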

Data Sampling

Use DistributedSampler to split the dataset into one subset per GPU (process).

import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST

target_transform = transforms.Lambda(lambda x: F.one_hot(th.tensor(x), num_classes=10).float())  # one-hot targets for the MSE loss used below

dataset_train = MNIST(root='data', train=True, transform=transforms.ToTensor(), target_transform=target_transform)
dataset_val = MNIST(root='data', train=False, transform=transforms.ToTensor())

sampler_train = DistributedSampler(dataset_train, num_replicas=world_size, rank=rank, shuffle=True)
sampler_val = DistributedSampler(dataset_val, num_replicas=world_size, rank=rank, shuffle=False)

data_loader_train = DataLoader(dataset_train, batch_size=256, sampler=sampler_train, num_workers=4, pin_memory=True)
data_loader_val = DataLoader(dataset_val, batch_size=256, sampler=sampler_val, num_workers=4, pin_memory=True)
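
Each rank now only iterates over its own shard, roughly len(dataset) / world_size samples. A hypothetical check:

print(f'rank {rank}: {len(sampler_train)} training samples in {len(data_loader_train)} batches')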

Create the Model

Wrap the model with DistributedDataParallel so that gradients are synchronized across all replicas.

from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])
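
The unwrapped model remains accessible as model.module; that is what the training loop below saves, so the checkpoint can later be loaded into a plain MyModel. A hypothetical check:

assert isinstance(model.module, MyModel)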

Training

from torch.optim import SGD
from timm.utils import accuracy
from pathlib import Path

criterion = nn.MSELoss()
optimizer = SGD(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    sampler_train.set_epoch(epoch)

    for samples, labels in data_loader_train:
        samples, labels = samples.cuda(), labels.cuda()

        outputs = model(samples)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate and save a checkpoint every other epoch and on the last epoch (rank 0 only)
    if (epoch % 2 == 0 or epoch + 1 == 10) and rank == 0:
        model.eval()
        total_acc1 = 0.0
        total_acc3 = 0.0
        num_batches = 0

        with th.inference_mode():
            for samples, labels in data_loader_val:
                samples, labels = samples.cuda(), labels.cuda()
                outputs = model(samples)
                acc1, acc3 = accuracy(outputs, labels, topk=(1, 3))
                total_acc1 += acc1.item()
                total_acc3 += acc3.item()
                num_batches += 1

        avg_acc1 = total_acc1 / num_batches
        avg_acc3 = total_acc3 / num_batches
        print(f'Epoch [{epoch}/10] - acc1: {avg_acc1:.4f}, acc3: {avg_acc3:.4f}')

        Path('checkpoints').mkdir(exist_ok=True)
        th.save(model.module.state_dict(), f'checkpoints/checkpoint-{epoch}.pt')  # save the unwrapped model's weights

# Exit and clean up the distributed process group
dist.destroy_process_group()
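
Note that the evaluation above only measures accuracy on rank 0's shard of the validation set. One way to cover the whole validation set, sketched here under the assumption that it runs before destroy_process_group: every rank scores its own shard, and the per-rank counts are combined with all_reduce.

# Sketch: each rank evaluates its own validation shard, then the counts are summed across ranks
model.eval()
correct = th.zeros(1, device='cuda')
total = th.zeros(1, device='cuda')
with th.inference_mode():
    for samples, labels in data_loader_val:
        samples, labels = samples.cuda(), labels.cuda()
        preds = model(samples).argmax(dim=-1)
        correct += (preds == labels).sum()
        total += labels.numel()
dist.all_reduce(correct)  # defaults to SUM
dist.all_reduce(total)
if rank == 0:
    print(f'global acc1: {(correct / total).item():.4f}')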

Load a Checkpoint

model = MyModel()  # the checkpoint holds the unwrapped module's weights, so it loads into a plain MyModel
state_dict = th.load("checkpoints/checkpoint-9.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
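
A hypothetical usage example, assuming the dataset_val defined earlier is available: run the restored model on a single validation image.

model.eval()
sample, label = dataset_val[0]
with th.inference_mode():
    pred = model(sample.unsqueeze(0)).argmax().item()
print(f'predicted: {pred}, ground truth: {label}')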

Elastic Distributed Scheduling

Distributed training can be launched with the torchrun command:

export OMP_NUM_THREADS=1

torchrun \
    --nnodes 1 \
    --nproc-per-node auto \
    --rdzv-id 0 \
    --rdzv-backend c10d \
    --rdzv-endpoint localhost:0 \
    main.py
  • --nnodes: number of nodes
  • --nproc-per-node: number of processes to launch on the current node
  • --rdzv-id: job ID
  • --rdzv-backend: rendezvous backend
  • --rdzv-endpoint: rendezvous endpoint (host:port)

For multi-node training, run the torchrun command on every node, or use a cluster job manager such as Slurm.
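
For example, a hypothetical two-node job would run the same command on both nodes, with --rdzv-endpoint pointing at the first node (the hostname and port below are placeholders):

torchrun \
    --nnodes 2 \
    --nproc-per-node 8 \
    --rdzv-id 42 \
    --rdzv-backend c10d \
    --rdzv-endpoint node0.example.com:29400 \
    main.py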


Accelerate

Accelerate is Hugging Face's wrapper around PyTorch DDP that makes DDP easier to use.

pip install accelerate  # install accelerate

Training

Implementing distributed training with Accelerate is simple; only the following adjustments to the original code are needed:

+ from accelerate import Accelerator
+ accelerator = Accelerator()

+ model, optimizer, dataloader, scheduler = accelerator.prepare(
+     model, optimizer, dataloader, scheduler
+ )

  for samples, labels in dataloader:
-     samples = samples.to(device)
-     labels = labels.to(device)
      outputs = model(samples)
      loss = criterion(outputs, labels)
      optimizer.zero_grad()
-     loss.backward()
+     accelerator.backward(loss)
      optimizer.step()
      scheduler.step()

Launch distributed training:

accelerate launch \
    --multi_gpu \
    --main_process_ip '127.0.0.1' \
    --main_process_port 29500 \
    --num_machines 1 \
    --num_processes 8 \
    --machine_rank 0 \
    --dynamo_backend no \
    --mixed_precision bf16 \
    train.py
  • --multi_gpu: enable multi-GPU distributed training
  • --main_process_ip: IP address of the main process
  • --main_process_port: port of the main process
  • --num_machines: number of nodes
  • --num_processes: total number of processes
  • --machine_rank: rank of the current node
  • --dynamo_backend: TorchDynamo backend, defaults to no
  • --mixed_precision: mixed-precision mode, one of {no,fp16,bf16,fp8}

Because Accelerate is a wrapper around PyTorch DDP, distributed training can also be launched with torchrun.
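
For example, a hypothetical single-node launch (Accelerate picks up the environment variables that torchrun sets):

torchrun --nnodes 1 --nproc-per-node 8 train.py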

Reference: Accelerate | Hugging Face Docs

Trainer

Trainer is the highest level of abstraction: we no longer write the training loop at all, we only configure the training arguments and Trainer takes care of everything for us:

import torch as th
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from transformers import Trainer, TrainingArguments

model = MyModel()

# Plain MNIST datasets (assumed here) with integer class labels, as expected by compute_loss below
train_dset = MNIST(root='data', train=True, transform=transforms.ToTensor())
test_dset = MNIST(root='data', train=False, transform=transforms.ToTensor())

# 设置超参数
training_args = TrainingArguments(
    "basic-trainer",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    eval_strategy="epoch",
    remove_unused_columns=False
)

def collate_fn(examples):
    """生成训练数据"""
    x = th.stack([example[0] for example in examples])
    labels = th.tensor([example[1] for example in examples])
    return { "x": x, "labels": labels }

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the loss; MyModel returns raw logits, so use cross-entropy"""
        outputs = model(inputs["x"])
        targets = inputs["labels"]
        loss = F.cross_entropy(outputs, targets)
        return (loss, outputs) if return_outputs else loss

trainer = MyTrainer(
    model,
    training_args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    data_collator=collate_fn
)

trainer.train()

Reference: From PyTorch DDP to Accelerate to Trainer, mastering distributed training with ease | Hugging Face
