[源码解析] 深度学习流水线并行 GPipe(3) ----重计算

[源码解析] 深度学习流水线并行 GPipe(3) ----重计算

0x00 摘要

GPipe是一个基于 Lingvo (Lingvo 是 Google 基于 TensorFlow 二次开发的重点针对序列模型的框架)开发的,支持超大规模模型的神经网络训练并行库,本文介绍其重计算功能,同时可以和其他实现一起印证。

本系列其他文章如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

0x01 概述

1.1 前文回顾

前文提到,目前分布式模型训练有几个必要并行技术:

  • 流水并行,尤其是如何自动设定流水;
  • 梯度累加(Gradient Accumulation);
  • 后向重计算;
  • 1F1B 策略(我们将采用PipeDream分析);

在前文中,我们介绍了Gpipe如何实施流水线并行技术,以及梯度累积。

流水并行存在一个问题:显存占用太大。如果每个 micro-batch 前向计算的中间结果(activation)被后向计算所消费,则需要在显存中缓存 n 份(梯度累加的次数)完整的前向 activation。这时就不得不用另一项重要的技术:重计算(Checkpointing)。

本文以论文"Training deep nets with sublinear memory cost"为基础,对于 pytorch 和 Gpipe 源码 进行分析,期望可以对 “Gradient checkpointing”技术有一个具体的理解。

1.2 Gradient checkpointing

2016年,陈天奇团队提出了亚线性内存优化相关的 "gradient/activation checkpointing(后向重计算)"等技术,旨在降低深度学习训练过程中的中间激活(activation)带来的显存占用。Checkpointing技术属于亚线性内存优化的一种,除此之外还有CPU offload等技术(CPU offload在微软Deepspeed框架中被广泛使用)。

梯度检查点是一种减少深度神经网络训练时内存消耗的系统性方法,具体是在反向传播中,针对每个设定为检查点的段,通过重新运行前向传播段来实现的:

  • 梯度检查点方法集中在减少降低存储中间结果(特征图)和梯度的内存开销,因为在许多常见的深度网络之中,与模型参数相比,中间结果要大得多。
  • 梯度检查点是一种以时间(算力)换空间(显存)的方法,通过减少保存的激活值压缩模型占用空间,但是在计算梯度时必须重新计算没有存储的激活值,即需要花两倍的前向传播计算时间。
  • 具体来说,就是设置一些梯度检查点,检查点之外的中间结果先释放掉,将来在反向传播的过程中如果发现前向结果不在显存中,就找到最近的梯度检查点再进行前向计算,恢复出被释放的张量。

0x02 背景知识

2.1 求导如何工作

此处借鉴了 训练时显存优化技术——OP合并与gradient checkpoint 的思路。

DNN模型由一系列不同类型的层组成(例如卷积层,全连接层,池化层)。

反向传播的关键是“自动链式求导”,但实际上BP在这个基础上也加入了一点动态规划机制。一般的BP包含以下两个步骤:

  • 前向传导。以图像分类为例,当前模型首先对一小部分训练样本(也称为minibatch)进行预测。这个过程被称为向前传导。
    • 为了进行预测,来自小批量的输入数据被输入到模型的第一层。
    • 然后,每一层在其输入上计算一个函数,为下一层生成输出。前向传导记录以下两个值:中间结点的输出值,输出值关于输入值的梯度。
    • 最后一层的输出是类预测。基于模型的预测标签和每个图像的实际标签,输出层计算损失(或错误)。
  • 反向传播梯度计算。反向传播就是一个计算网络最终输出值关于本层输出的梯度的过程。即,从输出开始,反向传播梯度值,计算输出值对于每一个中间变量的梯度,并保存。每层计算 前一层的误差,和 所有相关层的权重更新(损失梯度),这将使模型的预测朝着所需的输出移动。

在梯度回传的过程中需要用到节点的输出值,但是在反向传播进行梯度计算的时候,BP不会进行重复计算,其原因就是在前向传导的时候,进行了中间变量的存储,也就是每个中间节点的输出值。BP不断地反向传播梯度,并保存中间梯度,直到计算图的所有中间值以及初始值的梯度被求解完毕。

我们看看反向传播是如何工作的。

所谓自动求导框架实际上是“半自动”的:它并非直接求出一个复杂函数导数的解析形式,而是通过构建计算图和预先写好的基础函数的求导规则,结合链式求导法则实现的自动求导

我们假设一个函数为例进行说明,其表达式如下:

f(x) = x * (x + 1)

通过简单的数学推导得到其梯度的解析式为f'(x) = x + 1 + x;先把这个结果放一边,看看自动求导框架是如何一步步求出这个结果的,画出计算图如下:

                       +---------+
                       |         |
               +------>+  x + 1  +----+
               |       |         |    | 3
             2 |       +---------+    |
               |                      |
               |                      v
         +-----+--+                  ++------+
         |        |                  |       |
+------> |    x   +----------------> |   +   +---------->
         |        |         1        |       |
         +--------+                  +-------+

在计算图上,反向传播先经过乘法运算,根据上面的求导规则:

  • 路径1上的梯度为 x + 1
  • 路径3上的梯度为 x
  • 路径3再反向传播要经过路径2,除了其梯度为 x + 1 之外,还要乘上 路径2的梯度 1
  • 路径2和路径1汇聚到一起,所以最终的梯度为 x + 1(路径1)+ 1 * x(路径2)= x + 1 + x,刚好等于我们用数学公式计算出来的结果;

自动求导框架正是依靠这些基础的规则和链式求导法则在高效准确的运作。

在绝大多数神经网络的训练过程中,在计算反向传播时,前向传播过程中得到的一些中间变量非常有用(为了方便求导)。在实际操作中,最好代码实现对于这些中间变量的缓存,这样在反向传播的时候也能用上它们。于是显存占用的大头就是中间结果,也就是所谓的“特征图”。对于本文,x 就是前一层输出的中间结果(特征图)。

在适用乘法的求导规则时,要求我们要事先保留下中间结果 x 和 x+1。注意框架定义的乘法及其求导规则是通用规则,乘法的左右两边完全可能是不相关的两个值,所以必须同时保留下来。就是说,x + 1 在其他函数中,可能是 ( x + y ) + z ....,也可能包含其他输入变量,所以无法通过 + 1 这样简单的算式由一个输入 x 计算出来

在不考虑框架自身优化的情况下,显存占用就包括了一个 x 和 一个 x + 1,注意x可不是一个单独的数值,而是类似32x32x128这样大小的特征图。

2.2 梯度Checkpoint

如前一节所述,神经网络的原始方式中:

  • 在forward函数中,每层的激活函数值计算之后需要保存下来,因为它们需要在后向传播的计算中被消费。
  • 在backward时,根据损失函数值和该层对应的激活函数值计算梯度。
  • 因此,我们需要在显存中缓存 n 份(梯度累加的次数)完整的前向 activation。也就是说,这种情况下显存占用与 层数成正比。

因此,目前流水并行存在一个问题:显存占用太大。

是否可以不存储激活值?比如在backward时,需要激活函数值的时候重新进行forward就可以了。

假如我们一个都不存储,都通过forward重新计算?那么在大模型中这样消耗的时间太大。所以我们可以选用折中的方式,比如只存部分层的激活函数值。当backward需要激活函数值的时候,取最近的激活值就行。所以就引入了一项重要的技术:重计算(Checkpointing)。

2.3 论文内容

2.3.1 主要论文

Gpipe 的 Checkpointing 主要思路来自以下两篇论文:

  • Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of check- pointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1):19–45, 2000.
  • Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.

主要思路是用算力换内存(计算换显存,反向求导时需要的中间结果从 checkpoint 重新计算),以及用带宽换显存。

2.3.2 论文 Training Deep Nets with Sublinear Memory Cost

2.3.2.1 主要思路

我们主要来看这篇论文。

Checkpointing 是陈天奇在2016年发表的论文 Training Deep Nets with Sublinear Memory Cost 中提到的,也称之为亚线性内存优化。亚线性内存优化有两种思路,Checkpointing 和 CPU offload:

  • Checkpointing 的核心思想 是在前向网络中标记少量的 Tensor (被 Checkpointing 的 Tensor ),前向计算就只会保留这些被标记的 Tensor, 其余的前向的 activation,会通过在反向传播中根据 Checkpointing 的 Tensor 临时重新计算一遍前向得到。这样就使得大量的 activation 不需要一直保存到后向计算,有效减少了大量 Tensor 的生命周期,使得内存复用效率大幅提升。
  • CPU offload 的思路类比于计算机操作系统中的“虚拟内存”技术(将不常用的内存临时换入换出到磁盘上,从而增加内存总量),在深度学习中,GPU 显存(Device Memory)的特点是昂贵、高速且容量小,而 CPU 主存(Host Memory)的特点是便宜、相对低速和大容量;那么将前向计算中的一些暂时用不到的 activation 临时换出到 CPU 主存上,等到反向计算需要时再换入到 GPU 显存里,通过这种方式也可以节省显存。

两种亚线性内存优化通过不同的方式达到了显存优化:Checkpointing 是通过额外的计算开销换显存, CPU offload 通过额外的传输开销换显存。

2.3.2.2 Checkpointing 优化

上图展示了做 Checkpointing 之前和之后的计算图对比。

左面灰色的是网络配置。

中间 Normal Gradient Graph 是普通网络的前向后向传播流程。

右面 Memory Optimized Gradient Graph 就是应用了 gradient-checkpoint 的结果。为了进一步减少内存,会删除一些中间结果,并在需要时从额外的前向计算中恢复它们。

  • 首先,神经网络分为几个部分(右面图中就分成了三段),该算法只记住每一段的输出,并在每一段中删除所有中间结果。
  • 其次,在反向传播阶段,我们可以通过从最近的记录结果向前运行来重新计算丢弃的中间结果。
  • 因此,我们只需支付存储每个段的输出的内存成本加上在每个段上进行反向传播的最大内存成本。

所以gradient-checkpoint 就是并非是不需要中间结果,而是有办法在求导过程中实时的计算出之前被舍弃掉的中间结果

重计算并不是单独为流水并行设计的,并且之前大多使用在单卡或者数据并行场景下。但这个优化在流水并行下就非常关键,因为它使得前向不需要缓存所有的 activation,而只需要缓存非常少个数的(比如一层 Transformer Layer 只会缓存一个 )、被 checkpoint 的特定 Tensor ,从而大大节省了流水并行下的显存开销。

0x03 OpenAI

在OpenAI 提出的gradient-checkpoint 就是论文Training Deep Nets with Sublinear Memory Cost思路的实现,因为其文档比较齐全(https://github.com/openai/gradient-checkpointing),我们可以学习借鉴下。

总体思路是:在神经网络中间设置若干个检查点(checkpoint),对于中间结果feature map,每隔 sqrt(n)保留一个检查点。检查点以外的中间结果全部舍弃,反向传播求导数的时间,需要某个中间结果时,从最近的检查点开始计算,这样既节省了显存,又避免了从头计算的繁琐过程。

3.1 计算图

对一个简单的 n 层前馈神经网络,获取梯度的计算图如下所示:

具体如下:

  • 神经网络的层级激活值对应于 f 标记的节点,且在正向传播过程中,所有这些节点需要按顺序计算。
  • 损失函数对激活值和这些层级参数的梯度使用 b 节点标记,且在反向传播过程中,所有这些节点需要按逆序计算。
  • 计算 f 节点的激活值是进一步计算 b 节点梯度的前提要求,因此 f 节点在前向传播后会保留在内存中。
  • 只有当反向传播执行地足够远以令计算对应的梯度不再需要使用后面层级的激活值或 f 的子节点时,这些激活值才能从内存中清除。这意味着简单的反向传播要求内存与神经网络的层级数成线性增长关系。

3.2 重计算

简单的反向传播已经是计算最优的了,因为每个节点只需要计算一次。然而,如果我们愿意重新计算节点,那么我们可以节省大量的内存。当我们需要节点的激活值时,我们可以简单地重计算前向传播的节点激活值。我们可以按顺序执行计算,直到计算出需要使用激活值进行反向传播的节点。

使用这一策略,需要令计算梯度的内存在神经网络层的数量 n 上是稳定的,且 n 在内存方面是最优的。但是要注意,节点的计算数量现在扩展了 n^2,相比于之前的 n。n 个节点中的每一个被再计算 n 次。因此计算图变得很慢以计算深度网络,使得这一方法不适用于深度学习。

3.3 策略

为了在内存与计算之间取得平衡,我们需要一个策略允许节点被再计算,但是这种再计算不会发生很频繁。这里我们使用的策略是把神经网络激活的一个子集标记为一个节点。紫色的节点表示在给定的时间内需要储存在内存中。

这些检查点节点在前向传播后保留在内存中,而其余节点最多只会重新计算一次。在重新计算后,非检查点节点将保留在内存中,直到不再需要它们来执行反向传播。对于简单的前馈神经网络,所有神经元的激活节点都是由正向传播定义的连接点或图的分离点。这意味着我们在反向传播过程中只需要重计算 b 节点和最后检查点之间的节点,当反向传播达到了我们保存的检查点节点,那么所有从该节点开始重计算的节点在内存中都能够移除。

3.4 过程

首先,我们设定了两个checkpoint,图上第一行左面两个紫色,注意,右面第一个紫色是输入。

其次,正向传播已经完成,开始反向传播,就是从下面一行紫色1号开始反向传播。

第三,来到了下面一行的紫色2号,它依赖于上面的紫色3号来计算(回忆一下,后向传播计算需要前向计算的输出),此紫色3号是checkpoint,在内存中存在,所以正常执行反向传播

第四,来到了下面一行的白色 4 号,它依赖于上面的紫色 5 号来计算,5 号不是一个checkpoint,不在内存之中,需要重它前面的checkpoint开始计算,即从紫色 7 号开始计算。计算出来一个新的checkpoint,同时可以删除上面一行原有紫色 5 号,因为不需要了。

第五,计算出下面的新紫色 4 号,从而继续后向计算。

因为涉及到自动生成checkpoint,OpenAI这部分代码比较晦涩鬼畜,所以这里不进行分析,如果有兴趣的同学可以自行学习。

0x04 Pytorch 实现

我们接下来用Pyorch来看看。

4.1 基础知识

4.1.1 Variable & Function

在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法。它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的。

autograd.Variable 是autograd中最核心的类。 它包装了一个Tensor,并且几乎支持所有在其上定义的操作。一旦完成了你的运算,你可以调用 .backward()来自动计算出所有的梯度。

另一个对autograd的实现非常重要的类是Function,Function简单说就是对Variable的运算,如加减乘除,relu,pool等。但它不仅仅是简单的运算。与普通Python或者numpy的运算不同,Function是针对计算图,需要计算反向传播的梯度。因此他不仅需要进行该运算(forward过程),还需要利用cache保留前向传播的输入(为计算梯度),并支持反向传播计算梯度。

Pytorch是利用Variable与Function来构建计算图的。回顾下Variable,Variable就像是计算图中的节点,保存计算结果(包括前向传播的激活值,反向传播的梯度),而Function就像计算图中的边,实现Variable的计算,并输出新的Variable。

总结,Function与Variable构成了pytorch的自动求导机制,它定义的是各个Variable之间的计算关系。

备注:最新 PyTorch 代码之中,已经用把 Function 修改为 Node 类,应该是为了更好的表示计算图中节点的概念。

4.1.2 Function进一步理解

我们可以使用autograd.Function类来自定义一个模型、一个层、一个激活函数、一个损失函数,就更加好理解了,实际上本质上来说都是一个函数,只分这个函数是简单还是复杂。

4.2 普通模式

这部分代码位于torch/utils/checkpoint.py。pytorch是需要用户指定checkpoint,因此实现相对简单很多。

4.2.1 封装

在 torch/utils/checkpoint.py 之中,对checkpoint有了一个封装,该注释非常值得我们阅读,我们深入学习一下。

  • Checkpointing 本质就是用计算换内存。

  • Checkpointing 存储用于后向计算所需要的整个计算图的全部中间激活值,而是在反向传播中重新计算它们。

  • 在前向传播过程中,Checkpointing 参数 function 是运行在 torch.no_grad 模式,这样就不会计算中间激活值了。相反,向前传递保存输入元组和function参数。

  • 在向后传递中,保存的输入和function被取出,function将再次被计算,这次会跟踪中间激活值,然后

    使用这些激活值计算梯度。

def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)

4.2.2 处理设备

因为pytorch无法知道向前传播函数是否会把一些参数移动到不同的设备上,这就需要一些逻辑来保存为这些设备保存RNG状态。虽然可以为所有可见设备保存/恢复所有的RNG状态,但是这样在大多数情况下是一种浪费,因此作为折中,pytorch只是针对所有的张量参数的设备进行保存RNG状态。

def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

4.2.3 核心逻辑

CheckpointFunction 继承了torch.autograd.Function。

我们可以对Function进行拓展,使其满足我们自己的需要,而拓展就需要自定义Function的forward运算,以及对应的backward运算,同时在forward中需要通过保存输入值用于backward。

  • forward函输入tensor,计算输出tensor。

    • 在前向传播过程中,Checkpointing 参数 function 是运行在 torch.no_grad 模式,这样就不会计算中间激活值了。
    • 向前传递保存输入元组和function参数。
    • 对于CheckpointFunction来说,还是需要在forward之中存储一些另外的信息(就是上面说的 rng 信息),以供后向传播时候计算使用。
    • 进行前向传播返回激活值。
  • backward函数接收相对于某个标量值的输出张量的梯度,并且计算关于该相同标量值的输入张量的梯度。

    • 在向后传递中,保存的输入和function被取出。
    • function将再次被计算,这次会跟踪中间激活值,然后使用这些激活值计算梯度。
"""
我们可以通过建立torch.autograd的子类来实现我们自定义的autograd函数,
并完成张量的正向和反向传播。
"""
class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """
        在forward函数中,接收包含输入的Tensor并返回包含输出的Tensor。
        ctx是环境变量,用于提供反向传播是需要的信息。我们可以使用上下文对象来缓存对象,以便在反向传播中使用。可通过ctx.save_for_backward方法缓存数据,save_for_backward只能传入Variable或是Tensor的变量。
        """
        check_backward_validity(args)
        # 保存前向传播函数
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            # 存储前向传播时候的设备状态
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = [] 
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args): # 存储输入数值
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # `saved_for_backward`是会保留此input的全部信息, 并避免in-place操作导致的input在backward被修改的情况. 它是将函数的输入参数保存起来以便后面在求导时候再使用,起前向反向传播中协调作用。      
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args) # 进行前向传播
        return outputs

"""
在反向传播中,我们接收到上下文对象和一个张量,
其包含了相对于正向传播过程中产生的输出的损失的梯度。
我们可以从上下文对象中检索缓存的数据,
并且必须计算并返回与正向传播的输入相关的损失的梯度。
"""      
    # 自动求导是根据每个op的backward创建的graph来进行的
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # 获取前面保存的参数,也可以使用self.saved_variables

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices): # 利用存储的张量重新设置input
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        # 存储目前rng状态,模拟前向传播状态,最后恢复目前状态
        rng_devices = [] 
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state) # 恢复前向传播时候的设备状态
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                # 利用前向传播函数再次计算
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = [] # 激活值
        args_with_grad = [] # 梯度
        # 从前向传播计算的结果中筛选需要传播的张量
        for i in range(len(outputs)): 
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
            
        # 开始后向传播    
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads

4.3 Pipeline模式

我们接下来看看 流水线模式如何进行 Checkpoint。

Pytorch 流水型并行模式是受到了GPipe的启发,在其注释之中有提到。

通过CheckpointFunction,pytorch可以做到把重计算和递归反向传播合并到一个自动求导函数中,因此当梯度到达时,重计算就会开始。但是在流水线模式中,为了缩减GPU idle时间,重计算需要发生在梯度到达之前进行(因为重计算其实和梯度无关,重计算可以在梯度到来之前进行以获得激活值,等后向传播的梯度来了之后,再集合激活值进行自己的梯度计算)。

为了解决这个问题,pytorch引入了两个自动求导函数:class:Recompute and class:Checkpoint,分别代表重计算和递归反向传播就是把普通模式下的 CheckpointFunction 分离成两个阶段,这样用这两个函数就可以控制自动求导引擎和CUDA。具体说就是在class:Recompute and class:Checkpoint之间插入CUDA同步,这样把class:Checkpoint 推迟到梯度完全拷贝结束。

分开段,就可以多个流水线stage并行了。

4.3.1 样例

我们可以先看看 test/distributed/pipeline/sync/test_checkpoint.py 这个代码。

其通过log的巧妙打印,可以让我们看出来运行时候,checkpoint在前向后向传播之中的使用。

timeline 最后结果是 ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"],

其中两两一组,分别对应了 forward pass ,Checkpoint(Log[b]),Checkpoint(Log[a])。

@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    # 这里意味着最后 backward 实际会运行"a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)

    a, phony = fork(a)
    b = join(b, phony)

    # 这里意味着最后 backward 实际会运行"b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

4.3.2 共享变量

class:Recompute and class:Checkpoint之间具体是通过Context这个上下文来进行共享变量的保存。

# Types for shared memory between Checkpoint and Recompute.

Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]]  # (cpu_rng_state, gpu_rng_state)

class Context:
    """The common interface between the :class:`Checkpoint` and
    :class:`Recompute` context.
    """

    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool

    saved_tensors: Tuple[Tensor, ...]

    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        pass

4.3.3 rng state

根据运行时的不同,RNG状态可能会产生不同的性能影响,所以需要在每个检查点期间存储当前设备的RNG状态,在重计算之前恢复当前设备的RNG状态。

save_rng_states 和 restore_rng_states 两个方法分别用来存取 RNG 状态。

def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
    """:meth:`Checkpoint.forward` captures the current PyTorch's random number
    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state = torch.get_rng_state()

    gpu_rng_state: Optional[Tensor]
    if device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state(device)
    else:
        gpu_rng_state = None

    rng_states.append((cpu_rng_state, gpu_rng_state))


@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.

    .. seealso:: :ref:`Referential Transparency`

    """
    cpu_rng_state, gpu_rng_state = rng_states.pop()

    gpu_devices: List[torch.device] = []
    if device.type == "cuda":
        gpu_devices.append(device)

    with torch.random.fork_rng(gpu_devices):
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)
        yield

4.3.4 Checkpoint

Checkpoint 和下面的 Recompute 就是把普通模式下的 checkpoint 代码分离成两个阶段(forward函数被分成两段,backward 函数也被分成两段),从而可以更好的利用流水线。

class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # 存RNG状态
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        # 为BP做准备,其实目前没有实现
        ctx.save_for_backward(*input)

        # 进行前向计算
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)

        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        # 从保存的重计算变量中弹出所需变量
        output, input_leaf = ctx.recomputed.pop() 

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
            
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # 进行自动微分
            torch.autograd.backward(tensors, grad_output)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)

4.3.5 Recompute

Recompute 就是依据保存的信息,重新计算中间变量。

class Recompute(torch.autograd.Function):
  
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)

        # 取出保存的RNG状态,进行前向计算,得到中间变量
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)

        # 保存变量,为Checkpoint使用
        ctx.recomputed.append((output, input_leaf))

        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)

4.3.6 Pipeline

4.3.6.1 Task

我们首先要看看 Task 类。代码位于:torch/distributed/pipeline/sync/worker.py。

由注释可知,Task 就是用来在一个分区上计算一个micro-batch。

compute可以在worker线程内被并行执行。

finalize 应该在compute结束之后被执行。

class Task:
    """A task represents how to compute a micro-batch on a partition.

    It consists of two parts: :meth:`compute` and :meth:`finalize`.
    :meth:`compute` should be executed in worker threads concurrently.
    :meth:`finalize` should be executed after when worker threads complete to
    execute :meth:`compute`.

    :meth:`compute` might be boosted by worker threads. Because it produces
    several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
    are not serialized through GIL. So more than one CUDA API call can be
    produced at the same time.

    """

    def __init__(
        self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
    ) -> None:
        self.stream = stream
        self._compute = compute
        self._finalize = finalize
        self._grad_enabled = torch.is_grad_enabled()

    def compute(self) -> Batch:
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            return self._compute()

    def finalize(self, batch: Batch) -> None:
        if self._finalize is None:
            return
        with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
            self._finalize(batch)
4.3.6.2 compute

这里说的是 Pipeline 类的 compute 函数。

Pipeline 的逻辑如其注释所示(PyTorch的注释真的很翔实)。重点是 Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) 这里设置了如何进行checkpoint。

可以看到,这里会将 recompute 方法设置为 Task 的 finalize 方法,然后会计划重计算。

class Pipeline:
    """The pipeline parallelism for Pipe."""
    
    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = i < checkpoint_stop
            if checkpoint:

                def function(
                    input: TensorOrTensors,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(input)

                # 这里进行处理
                chk = Checkpointing(function, batch)
                # 分别设置了chk.checkpoint 和 chk.recompute
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:

                def compute(
                    batch: Batch = batch,
                    partition: nn.Sequential = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task) # 将task插入到 pipeline的queue,这样可以并行。

        for i, j in schedule: 
            ok, payload = self.out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            # 取出 task    
            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch) # 计划进行重计算

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

关于 PyTorch 的 Pipeline,后续会有专门系列进行分析。

0x05 Gpipe实现

Gpipe 在反向传播的时候,可以在第 k-th 个 accelerator 上重新计算前向传播函数 F_k。

5.1 API函数 _Rematerialize

首先,我们看看API方法。

在 builder.py 之中有 _Rematerialize 函数,可以用来包装一个需要重新计算的层。

  def _Rematerialize(self, name, body):
    """Forces rematerialization on FProp of the body layer."""
    return builder_layers.RematerializationLayer.Params().Set(
        name=name, body=body)

5.2 包装层 RematerializationLayer

RematerializationLayer 是包装层,其中有:

FProp 就是把被封装层 包装为一个函数 Fn,然后调用 py_utils.RematerializeFn 把 Fn 与 输入变量一起传入。

class RematerializationLayer(base_layer.BaseLayer):
  """A wrapper layer with rematerialization."""

  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('body', None,
             'The main layer whose FProp will be wrapped by RematerializeFn.')
    return p

  def __init__(self, params):
    super().__init__(params)
    self.CreateChild('body', self.params.body)

  def FProp(self, theta, *xs):
    input_list = theta.body.Flatten() # 得到theta
    theta_len = len(input_list)
    input_list += list(xs) # 得到输入参数
    input_len = len(input_list)

    def Fn(*args): # 包装函数,会调用被封装层的 FProp
      body_theta = theta.body.Pack(args[:theta_len])
      return self.body.FProp(body_theta, *args[theta_len:input_len])

    return py_utils.RematerializeFn(Fn, *input_list) # 调用,执行FProp,并且做Gradient checking

  @classmethod
  def FPropMeta(cls, p, *args): # 就是传播被封装层的信息
    py_utils.CheckShapes(args)
    return p.body.cls.FPropMeta(p.body, *args)

3.2.3 tensorflow gradients 函数

RematerializeFn 调用了 tensorflow gradients 函数 来计算梯度,所以我们需要解释下。

在tensorflow中,gradients 函数可以自动计算函数的梯度。我们只需要设计我们的函数,然后去调用 tf.gradients 函数就可以了。

tf.gradients()的参数如下,其中

  • tf.gradients()实现ysxs求导
  • grad_ys也是一个list,其长度等于len(ys)。这个参数的意义在于对xs中的每个元素的求导权重。
tf.gradients(ys, xs, 
			 grad_ys=None, 
			 name='gradients',
			 colocate_gradients_with_ops=False,
			 gate_gradients=False,
			 aggregation_method=None,
			 stop_gradients=None)

5.4 功能函数 RematerializeFn

RematerializeFn 是最终功能函数,就是调用 fn,并且在反向传播过程中进行rematerializes fn。

def RematerializeFn(fn, *xs):
  """Calls fn and rematerializes fn in the backward pass.

  `fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.

  Args:
    fn: A python function to be rematerialized in the backprop pass.
    *xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
      fn function.

  Returns:
    `fn(*xs)`
  """
  initial_step_seed = GetStepSeed()
  final_step_seed = MaybeGenerateSeedFromScope()

  def Backward(fwd_xs, fwd_ys, d_fwd_ys):
    """The backward function that rematerializes forward outputs."""
    del fwd_ys # 去掉传入的参数,因为在内部需要用备份的Checkpoint来处理
    always_true = tf.random.uniform([]) < 2.0
    # Alternatively, can do this:
    # tf.where(tf.math.is_nan(x),
    #          tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
    #          x)
    bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs] # 依据Checkpoint来生成 bak_xs
    for dst, src in zip(bak_xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(initial_step_seed)
    ys = fn(*bak_xs) # 依据Checkpoint来重新生成ys
    MaybeResetStepSeed(final_step_seed)
    dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys) # ys 对 bak_xs 求导
    dxs_final = [] # 聚合
    for dx, x in zip(dxs, bak_xs):
      if dx is None:
        dxs_final.append(tf.zeros_like(x))
      else:
        dxs_final.append(dx)
    assert len(dxs_final) == len(bak_xs)
    return NestedMap(
        initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)

  ys_shapes = []

  # TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
  def Forward(fwd_xs):
    """Forward function plus sanity checks."""
    for dst, src in zip(fwd_xs.xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(fwd_xs.initial_step_seed)
    ys = fn(*fwd_xs.xs) # 正常计算
    # Some sanity check.
    assert not GetExtraInputs()
    assert not GetExtraArgs()
    assert not GetExtraVars()
    if isinstance(ys, tuple):
      for y in ys:
        assert isinstance(y, tf.Tensor)
        ys_shapes.append(y.shape)
    else:
      assert isinstance(ys, tf.Tensor)
      ys_shapes.append(ys.shape)
    return ys

  ys = CallDefun(
      Forward,
      NestedMap(initial_step_seed=initial_step_seed, xs=xs),
      bak=Backward)
  if isinstance(ys, tuple):
    for y, s in zip(ys, ys_shapes):
      y.set_shape(s)
  else:
    ys.set_shape(ys_shapes[0])
  # TODO(b/129159299): The ResetStepSeed below is needed to work around this
  # bug, which is a problem with global tensors being shared by different
  # inference graphs. It should be replaced with the new step seed value
  # returned from the Forward function when the bug is fixed.
  MaybeResetStepSeed(final_step_seed)
  return ys

CallDefun定义如下,就是把fwd, back封装起来进行调用。其中,Function 的作用是依据一个callable 构建一个TensorFlow graph function

def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
  """Wraps fwd in a defun with custom gradient bak and calls it with args.

  Args:
    fwd: A callable xs: Nested Structure -> ys: Nested Structure.
    args: A Nested Structure of tf.Tensor or None.
    bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
      Structure. The custom backprop function for fwd. bak needs to return
      dcapture if fwd uses any implicitly captured tensors, whose gradients are
      dcapture.
    bak_as_function: Whether to create a TF graph function for bak.
    device: the device on which to run fwd and bak.

  Returns:
    A Nested Structure equivalent to what fwd(args) computes.
  """
  if args is not None:
    args = Transform(tf.convert_to_tensor, args)
  sigs = Function(
      fwd_sig=TensorSpecs(args),
      bak=bak,
      bak_as_function=bak_as_function,
      device=device)(
          fwd=fwd)
  if args is None:
    return sigs()
  else:
    return sigs(args)

至此,GPipe 分析完毕,下一篇开始分析 PipeDream,敬请期待。

0xFF 参考

lingvo框架走读笔记

Tensorflow实现先累加多个minibatch计算的梯度,再反向传播

用tensorflow2实现梯度累积

十倍模型计算时间仅增20%:OpenAI开源梯度替换插件

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

论文解读系列第五篇:微软斯坦福等PipeDream快速训练大规模神经网络

https://cs231n.github.io/neural-networks-3/#gradcheck

https://www.cnblogs.com/geekfx/p/14182048.html

训练时显存优化技术——OP合并与gradient checkpoint

Pytorch笔记04-自定义torch.autograd.Function

PyTorch教程之Autograd

pytorch的自定义拓展之(三)——torch.autograd.Function的简单定义与案例

pytorch的自定义拓展之(二)——torch.autograd.Function完成自定义层

PyTorch 源码解读之 torch.autograd:梯度计算详解

再谈反向传播(Back Propagation)

CS231n课程笔记翻译:反向传播笔记

posted @ 2021-08-30 19:43  罗西的思考  阅读(203)  评论(0编辑  收藏  举报