Pytorch分布式训练Step by Step
DDP Overview
在开始介绍Pytorch分布式训练之前,需要首先理解有关于分布式训练的一些概念:
master node:主GPU,负责同步,模型加载,拷贝,写日志等操作。process group:在K个GPU上训练/测试模型,则K个进程形成一个组,该组由backend支持,Pytorch将会管理backend,推荐使用nccl作为DDP backendrank:在进程组中的每个进程通过rank进行标识,从0到K-1world size:进程组内GPU的个数K
Pytorch为分布式训练提供了两种设置:**torch.nn.DataParallel(DP)** 和 **torch.nn.parallel.DistributedDataParallel(DDP)**,官方推荐使用DDP,因为DDP比DP更快、更灵活。DDP所做的基本工作是将模型复制到多个GPU,收集梯度,平均梯度以更新模型,然后在K个进程上同步模型,还可以通过torch.distributed.gather/scatter/reduce等操作收集/分发tensors/objects
如果模型可以在一个GPU上进行训练,并且想要在K个GPU上训练和测试,DDP的最佳实践是将模型复制到K个GPU上(DDP会自动执行此操作),并将dataloader拆分为K个不重复的组分别输入K个模型。
DDP Step
1. setup the process group
import torch.distributed as dist
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
2. Split the dataloader
我们可以通过torch.utils.data.DistributedSampler轻松split我们的dataloader,sampler返回一个带索引的迭代器,这些索引被送入到dataloader中对数据进行划分
DistributedSampler将数据的总索引划分成world_size份,均匀地分配给每个进程中的dataloader,并且不会重复。
from torch.utils.data.distributed import DistributedSampler
def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
dataset = Your_Dataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
return dataloader
假设K=3,数据集的长度为10,DistributedSampler对索引的划分的规则如下:
- 如果我们在定义DistributedSampler的时候,设置
drop_last=False,它将会自动补全,例如:索引[0,1,2,3,4,5,6,7,8,9], (rank=1)=> [0,3,6,9],(rank=2)=>[0,4,7,0], (rank=3)=>[2,5,8,0],这样补全可能会存在问题,因为填充的0是一个数据记录 - 否则,它将去除尾部元素,例如,它分拆索引(rank=1)=> [0,3,6],(rank=2)=>[0,4,7], (rank=3)=>[2,5,8],这种情况下,它丢掉了9被wolrd size整除
自定义我们的Sampler非常简单,我们只需要创建一个类,然后定义它的__iter__()和__len__()函数,有关更多细节,可参考官方文档
Note:最好在分布式训练时将num_workers设置为0,因为在子进程中创建额外的进程可能会有问题,将pin_memory设置为False可以避免一些可怕的bug
3. Wrap the model with DDP
❓ 关于dist.barrier()的疑惑
何时使用dist.barrier()? 根据文档所述,dist.barrier() 同步进程,换句话讲,它阻塞process直到它们都到达dist.barrier()这句代码,dist.barrier()的使用可总结如下:
- 训练过程不需要使用dist.barrier(),因为DDP会自动帮我们处理(in loss.backward())
- 当gathering data的时候也不需要使用,因为dist.all_gather_object会帮我们做这些工作
- 我们在执行有序代码的时候需要使用dist.barrier(),例如一个进程加载另一个进程保存的模型
参考:

浙公网安备 33010602011771号