torch.DistributedDataParallel v.s. DataParallel

下面详细对比 PyTorch 中 DistributedDataParallel(简称 DDP)和 DataParallel 的用法和主要区别:


1. 使用方式

DataParallel 用法(单进程多卡,多卡在同一台机器)

import torch
model = ...  # 实例化模型
model = torch.nn.DataParallel(model, device_ids=[0,1,2,3])
model = model.cuda()

训练时直接像单卡一样用 model.forward()、model.backward()、optimizer.step() 即可。


DistributedDataParallel 用法(多进程多卡,支持单机或多机)

DDP 推荐每个进程只负责一张卡(单进程单卡),通常用 torch.multiprocessing 启动多个进程。

典型启动方式:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    setup(rank, world_size)
    model = ...  # 实例化模型
    model = model.to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    # 后续训练和单卡类似
    cleanup()

if __name__ == "__main__":
    world_size = 4
    torch.multiprocessing.spawn(demo_basic, args=(world_size,), nprocs=world_size)
  • 需要初始化通信组(dist.init_process_group
  • 每个进程只负责一张卡(model.to(rank)
  • 用 DDP 包裹模型(DDP(model, device_ids=[rank])

2. 主要区别

项目 DataParallel DistributedDataParallel (DDP)
并行方式 单进程多卡 多进程多卡(推荐一进程一卡)
通信方式 主进程负责切分数据,收集梯度,存在单进程瓶颈 多进程通信,梯度同步高效,无主进程瓶颈
适用场景 单机多卡 单机多卡/多机多卡
性能 较低,主进程瓶颈明显 高,几乎线性加速
代码复杂度 简单 略高:需初始化通信组、管理多进程
推荐使用 小模型/实验/快速原型 生产环境/大模型/大数据训练

3. 总结

  • DataParallel 使用简单,但性能受限于主进程,适合简单或小规模实验。
  • DistributedDataParallel (DDP) 性能优越,推荐用于实际生产和大规模训练,但需管理多进程和通信组初始化。

代码片段对比

DataParallel:

model = torch.nn.DataParallel(model)
model = model.cuda()

DistributedDataParallel:

dist.init_process_group(backend='nccl')
model = model.to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

如需进一步了解 DDP 的最佳实践或分布式训练环境配置,欢迎继续提问!

posted @ 2025-05-30 14:16  Gold_stein  阅读(87)  评论(0)    收藏  举报