Pytorch分布式训练Step by Step

DDP Overview

在开始介绍Pytorch分布式训练之前,需要首先理解有关于分布式训练的一些概念:

  • master node:主GPU,负责同步,模型加载,拷贝,写日志等操作。
  • process group:在K个GPU上训练/测试模型,则K个进程形成一个组,该组由backend支持,Pytorch将会管理backend,推荐使用nccl作为DDP backend
  • rank:在进程组中的每个进程通过rank进行标识,从0到K-1
  • world size:进程组内GPU的个数K

Pytorch为分布式训练提供了两种设置:**torch.nn.DataParallel(DP)****torch.nn.parallel.DistributedDataParallel(DDP)**,官方推荐使用DDP,因为DDP比DP更快、更灵活。DDP所做的基本工作是将模型复制到多个GPU,收集梯度,平均梯度以更新模型,然后在K个进程上同步模型,还可以通过torch.distributed.gather/scatter/reduce等操作收集/分发tensors/objects

如果模型可以在一个GPU上进行训练,并且想要在K个GPU上训练和测试,DDP的最佳实践是将模型复制到K个GPU上(DDP会自动执行此操作),并将dataloader拆分为K个不重复的组分别输入K个模型。

DDP Step

1. setup the process group

import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

2. Split the dataloader

我们可以通过torch.utils.data.DistributedSampler轻松split我们的dataloader,sampler返回一个带索引的迭代器,这些索引被送入到dataloader中对数据进行划分
DistributedSampler将数据的总索引划分成world_size份,均匀地分配给每个进程中的dataloader,并且不会重复。

from torch.utils.data.distributed import DistributedSampler

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = Your_Dataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
	
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)
	
    return dataloader

假设K=3,数据集的长度为10,DistributedSampler对索引的划分的规则如下:

  • 如果我们在定义DistributedSampler的时候,设置drop_last=False,它将会自动补全,例如:索引[0,1,2,3,4,5,6,7,8,9], (rank=1)=> [0,3,6,9],(rank=2)=>[0,4,7,0], (rank=3)=>[2,5,8,0],这样补全可能会存在问题,因为填充的0是一个数据记录
  • 否则,它将去除尾部元素,例如,它分拆索引(rank=1)=> [0,3,6],(rank=2)=>[0,4,7], (rank=3)=>[2,5,8],这种情况下,它丢掉了9被wolrd size整除

自定义我们的Sampler非常简单,我们只需要创建一个类,然后定义它的__iter__()和__len__()函数,有关更多细节,可参考官方文档

Note:最好在分布式训练时将num_workers设置为0,因为在子进程中创建额外的进程可能会有问题,将pin_memory设置为False可以避免一些可怕的bug

3. Wrap the model with DDP

❓ 关于dist.barrier()的疑惑

何时使用dist.barrier()? 根据文档所述,dist.barrier() 同步进程,换句话讲,它阻塞process直到它们都到达dist.barrier()这句代码,dist.barrier()的使用可总结如下:

  1. 训练过程不需要使用dist.barrier(),因为DDP会自动帮我们处理(in loss.backward())
  2. 当gathering data的时候也不需要使用,因为dist.all_gather_object会帮我们做这些工作
  3. 我们在执行有序代码的时候需要使用dist.barrier(),例如一个进程加载另一个进程保存的模型

参考:

  1. Pytorch distributed data parallel step by step
posted @ 2022-04-26 18:09  灵客风  阅读(705)  评论(0)    收藏  举报