[源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer

[源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,接下来我们通过几篇文章来看看如何把这些模块应用到实践之中,顺便把PyTorch分布式逻辑整体梳理一下。本文介绍如何把分布式自动微分和分布式优化器结合起来训练一个模型。

本文以 https://pytorch.org/tutorials/intermediate/rpc_tutorial.html 的部分翻译为基础,加入了自己的理解。

PyTorch分布式其他文章如下:

深度学习利器之自动微分(1)

深度学习利器之自动微分(2)

[源码解析]深度学习利器之自动微分(3) --- 示例解读

[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

[源码解析] PyTorch如何实现前向传播(3) --- 具体实现

[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

[源码解析] PyTorch 分布式(1)------历史和概述

[源码解析] PyTorch 分布式(2) ----- DataParallel(上)

[源码解析] PyTorch 分布式(3) ----- DataParallel(下)

[源码解析] PyTorch 分布式(4)------分布式应用基础概念

[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

[源码解析] PyTorch 分布式 Autograd (1) ---- 设计

[源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

[源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

[源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

[源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上)

[源码解析] PyTorch 分布式 Autograd (6) ---- 引擎(下)

[源码解析] PyTorch分布式优化器(1)----基石篇

[源码解析] PyTorch分布式优化器(2)----数据并行优化器

[源码解析] PyTorch分布式优化器(3)---- 模型并行

0x01 说明

首先要做一下说明,原文有两部分:强化学习和RNN,本文只是翻译了RNN部分。而且本文没有完全按照原文顺序进行翻译,而是按照自己理解的思路重新组织了文章,用一种从上至下的角度来看这个系统。

本文使用RNN模型来展示如何使用RPC API构建分布式模型并行训练。示例RNN模型非常小,可以很容易地放入单个GPU中,但我们仍然将它的层分在两个不同worker来之上来演示如何分布式训练。开发人员可以应用类似的技术在多个设备和机器上分发更大的模型。

注:在官方这些分布式文章中,worker 有时指代分布式系统之中所有进程,而实际训练进程往往叫做 trainer,本文的worker 就包括一个 trainer 和 一个参数服务器。

0x02 启动

在启动阶段,run_worker 方法会启动一个 trainer 和 一个参数服务器,参数服务器在代码之中没有任何行为。

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        # 启动了trainer
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        # trainer 业务逻辑
        _run_trainer()
    else:
        # 启动了参数服务器
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)

具体如下图:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1  |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
|                         |   |                          |
+-------------------------+   +--------------------------+

0x03 Trainer

我们接下来看看训练循环。初始化模型参数后,我们创建"RNNModel"和"DistributedOptimizer"。分布式优化器将获取参数"RRefs"的列表,查找这些参数所有的不同的 owner workers,并使用给定参数(即"lr=0.05")在每个owner worker上创建给定的本地优化器(在本例中即"SGD",您也可以使用其他本地优化器)。

在训练循环中,它做如下操作:

  • 首先创建分布式autograd context,这将帮助分布式autograd引擎查找梯度和涉及的RPC send/recv 函数。
  • 然后,它像本地模型一样开始向前传播,并且运行分布式向后传播。对于分布式后向传播,您只需要指定根的列表(list of roots),在本例中,它是loss 张量。分布式autograd引擎将自动遍历分布式计算图并正确写入梯度。
  • 接下来,它在分布式优化器上运行'step'函数,该函数将与所有相关的本地优化器联系以更新模型参数。与本地训练相比,一个区别是用户不需要运行 zero_grad() ,因为每个autograd context 都有专用的空间来存储梯度,这样每次迭代创建一个上下文时,来自不同迭代的梯度不会累积到同一组张量之上。

具体代码如下:

def run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2
    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer( # 创建分布式优化器
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id: # 创建分布式上下文
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss]) # 执行分布式后向传播
                # run distributed optimizer
                opt.step(context_id) # 分布式优化器进行更新
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))

逻辑扩展为:

           torch.multiprocessing.spawn
                      +
                      |
                      |
    +-----------------+--------------------+
    |                                      |
    |                                      |
    v                                      v
+---+---------------------+   +------------+-----------------------------------+
| "ps"          rank = 0  |   | "trainer"      rank = 1                        |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    model = rnn.RNNModel('ps')                  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    dist_autograd.backward(context_id, [loss])  |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |    DistributedOptimizer.step(context_id)       |
|                         |   |                                                |
|                         |   |                                                |
|                         |   |                                                |
+-------------------------+   +------------------------------------------------+

0x04 模型

我们接下来看看具体模型。

4.1 组件

RNN模型设计借鉴了PyTorch示例库 example中的word语言模型,该模型包含三个主要组件:嵌入表、LSTM层和解码器。

4.1.1 参考代码

我们有必要贴出原始参考代码来比对,可以看到,Embedding 和 Linear 都是作为 RNNModel 的成员变量存在,整个 RNNModel 耦合的非常紧密。

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp) # 嵌入表成员变量
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken) # 解码器成员变量

			  # 省略后部分代码

4.1.2 分布式修改

我们看看如何依据分布式的特点来对上面模型进行修改。

下面的代码将嵌入表(embedding table)和解码器包装到子模块(sub-modules)中,以便将它们的构造函数传递给RPC API。在EmbeddingTable子模块中,我们有意将嵌入层放在GPU上以做演示。在v1.4中,RPC总是在目标工作进程上创建CPU张量参数或返回值。如果函数采用GPU张量,则需要显式地将其移动到适当的设备。

class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))

4.2 RNN 模型

前面提到,为了实现分布式模型并行训练,开发人员可以将模型划分为子模块。有了上面的子模块,我们现在可以使用RPC将它们组合在一起,创建一个RNN模型。我们将调用RPC远程创建子模块实例,并在必要时使用RRef查找它们。正如您在下面的代码中所看到的,它看起来非常类似于单机模型并行训练。主要区别在于用RPC函数替换 Tensor.to(device)

ps表示一个参数服务器,它承载嵌入表和解码器的参数。构造函数使用remote API在参数服务器上创建EmbeddingTable对象和解码器对象,并在本地创建LSTM子模块。

在向前传播过程中,trainer使用EmbeddingTable RRef查找远程子模块,并使用RPC将输入数据传递给EmbeddingTable并获取查找结果。然后,它通过本地LSTM层运行嵌入,最后使用另一个RPC将输出发送到解码器子模块。

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the rremote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

因此,逻辑图拓展如下:

                 torch.multiprocessing.spawn
                            +
                            |
                            |
          +-----------------+--------------------+
          |                                      |
          |                                      |
          v                                      v
+---------+------------+   +---------------------+-------------------------------------+
|"ps"        rank = 0  |   | "trainer"      rank = 1                                   |
|                      |   |                                                           |
|                      |   |   model = rnn.RNNModel('ps')                              |
|                      |   |                                                           |
| +---------------+    |   |   +---------------------------------------+               |
| |EmbeddingTable |    |   |   | RNNModel                              |               |
| |               |    |   |   |                                       |               |
| |               | <--------------+ self.emb_table_rref               |               |
| +---------------+    |   |   |                                       |               |
| +---------------+    |   |   |                                       |               |
| |Decoder        | <--------------+ self.decoder_rref                 |               |
| |               |    |   |   |                                       |               |
| |               |    |   |   |     self.rnn = LSTM                   |               |
| |               |    |   |   |                                       |               |
| +---------------+    |   |   +---------------------------------------+               |
|                      |   |                                                           |
|                      |   |                                                           |
|                      |   |   forward() {                                             |
|                      |   |       emb = _remote_method(EmbeddingTable.forward, input) |
|                      |   |       output, hidden = self.rnn(emb, hidden)              |
+----------------------+   |       decoded = _remote_method(Decoder.forward, output)   |
                           |   }                                                       |
                           |                                                           |
                           |                                                           |
                           |   dist_autograd.backward(context_id, [loss])              |
                           |                                                           |
                           |                                                           |
                           |   DistributedOptimizer.step(context_id)                   |
                           |                                                           |
                           +-----------------------------------------------------------+


4.3 分布式优化器

在介绍分布式优化器之前,让我们添加一个helper函数,此函数用来生成模型参数的RRefs列表,分布式优化器将使用该列表。在本地训练中,应用程序可以调用 Module.parameters()来获取对所有参数张量的引用,并将其传递给本地优化器进行后续更新。但是,由于某些参数存在于远程机器上,因此同一API在分布式训练场景中不起作用。因此,分布式优化器不采用参数"张量"列表,而是采用"RRef"列表,本地和远程模型参数的每个模型参数都有一个"RRef"。helper函数非常简单,只需调用Module.parameters() 并在每个参数上创建一个本地'RRef'。

def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

然后,由于RNNModel包含三个子模块,我们需要调用 _parameter_rrefs 三次,并将其封装到另一个helper函数中。

class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

在 trainer 之中,使用如下来生成分布式优化器,这样就把远端的一些参数作为优化对象。

# setup distributed optimizer
opt = DistributedOptimizer(
    optim.SGD,
    model.parameter_rrefs(),
    lr=0.05,
)

我们最后拓展如下:

  • (1) RNNModel 的 emb_table_rref 成员变量指向参数服务器上的EmbeddingTable。
  • (2) RNNModel 的 decoder_rref 成员变量指向参数服务器上的Decoder。
  • (3) RNNModel 的 rnn 成员变量指向本地的LSTM。
  • DistributedOptimizer 内部的三个待优化变量分别指向:4) 参数服务器上的EmbeddingTable 的 参数,5) 参数服务器上的Decoder 的参数,6) 本地LSTM的参数。

分别对应下图上的数字。

                 torch.multiprocessing.spawn
                            +
                            |
                            |
            +---------------+--------------------+
            |                                    |
            |                                    |
            v                                    v
  +---------+------------+ +---------------------+----------------------------------------+
  |"ps"        rank = 0  | | "trainer"                                         rank = 1   |
  |                      | |                                                              |
  |                      | |   model = rnn.RNNModel('ps')                                 |
  |                      | |                                                              |
  |  +---------------+   | |   +---------------------------------------+                  |
  |  |EmbeddingTable |   | |   | RNNModel                              |                  |
+--->+               |   | | 1 |                                       |                  |
| |  |               +<------------+ self.emb_table_rref               |    +------+      |
| |  +---------------+   | |   |                            3          |    |LSTM  |  6   |
| |                      | |   |     self.rnn +---------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |    |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |    +------+    | |
| |  |               |   | |   |                                       |                | |
| |  |               |   | |   +---------------------------------------+                | |
| |  |               |   | |                                                            | |
| |  +------+--------+   | |   forward() {                                              | |
| |         ^            | |       emb = _remote_method(EmbeddingTable.forward, input)  | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)               | |
| |         |            | |       decoded = _remote_method(Decoder.forward, output)    | |
| |         |            | |   }                                                        | |
| +----------------------+ |                                                            | |
|           |              |   dist_autograd.backward(context_id, [loss])               | |
|           |              |                                                            | |
| 5         | 4            |  +------------------------------------------------------+  | |
|           |              |  | DistributedOptimizer                                 |  | |
|           |              |  |                                                      |  | |
|           |              |  |     remote_optimizers = [                            |  | |
+-------------------------------------------------------+ optim_rref1,               |  | |
            |              |  |                           optim_rref2+------------------+ |
            +-------------------------------------------+ optim_rref3                |    |
                           |  |                                                      |    |
                           |  |                          ]                           |    |
                           |  |     step(context_id)                                 |    |
                           |  +------------------------------------------------------+    |
                           +--------------------------------------------------------------+

手机如下:

4.4 比对

因为前面提到:分布式模型并行训练看起来非常类似于单机模型并行训练。主要区别在于用RPC函数替换 Tensor.to(device)。我们用GPU替代参数服务器,把上图大致修改下做一下对比,可能不是非常确切,但是大家可以看出来分布式训练的关键点。

  +----------------------+ +-------------------------------------------------------------+
  | GPU                  | | CPU                                                rank = 0 |
  |                      | |                                                             |
  |                      | |   model = rnn.RNNModel()                                    |
  |                      | |                                                             |
  |  +---------------+   | |   +---------------------------------------+                 |
  |  |EmbeddingTable |   | |   | RNNModel                              |                 |
+--->+               |   | | 1 |                                       |                 |
| |  |               +<------------+ self.emb_table_rref               |   +------+      |
| |  +---------------+   | |   |                            3          |   |LSTM  |  6   |
| |                      | |   |     self.rnn +--------------------------->+      +<---+ |
| |  +---------------+   | | 2 |                                       |   |      |    | |
| |  |Decoder        +<------------+ self.decoder_rref                 |   +------+    | |
| |  |               |   | |   |                                       |               | |
| |  |               |   | |   +---------------------------------------+               | |
| |  |               |   | |                                                           | |
| |  +------+--------+   | |   forward() {                                             | |
| |         ^            | |       emb = EmbeddingTable.forward(input)                 | |
| |         |            | |       output, hidden = self.rnn(emb, hidden)              | |
| |         |            | |       decoded = Decoder.forward(output)                   | |
| |         |            | |   }                                                       | |
| +----------------------+ |                                                           | |
|           |              |   loss.backward()                                         | |
|           |              |                                                           | |
| 5         | 4            |  +----------------------------------------+               | |
|           |              |  | Optimizer                              |               | |
|           |              |  |                                        |               | |
|           |              |  |          param_groups = [              |               | |
+-------------------------------------------------------+ optim_rref1, |               | |
            |              |  |                                        |               | |
            |              |  |                           optim_rref2+-----------------+ |
            |              |  |                                        |                 |
            +-------------------------------------------+ optim_rref3  |                 |
                           |  |                          ]             |                 |
                           |  |          step()                        |                 |
                           |  |                                        |                 |
                           |  +----------------------------------------+                 |
                           +-------------------------------------------------------------+

手机如下:

0xFF 参考

GETTING STARTED WITH DISTRIBUTED RPC FRAMEWORK

posted @ 2021-12-13 09:39  罗西的思考  阅读(845)  评论(0编辑  收藏  举报