深度学习基本功——自动微分的正向模式与反向模式:理解JVP与VJP
引言
在现代机器学习和深度学习中,自动微分(Automatic Differentiation, AD)扮演着至关重要的角色。它使得我们能够有效地计算复杂函数的导数,从而在优化过程中更新模型参数。本文将探讨自动微分的两种主要模式:正向模式(forward mode)和反向模式(inverse mode),以及它们对应的雅可比-向量乘法(JVP)和向量-雅可比乘法(VJP)。最后,我们将分析为什么当前深度学习神经网络普遍采用反向模式。
本文参考了以下的自动微分时的资料,这篇博客也算是对自己学习的总结:
正向模式和反向模式
我们先考虑输出是标量的情况,假设我们有以下函数关系:
  
      
       
        
        
          y 
         
        
          = 
         
        
          f 
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          = 
         
        
          a 
         
        
          ( 
         
        
          b 
         
        
          ( 
         
         
         
           c 
          
         
           ( 
          
         
           x 
          
         
           ) 
          
         
        
          ) 
         
        
          ) 
         
        
       
         y=f(\mathbf{x})=\mathbf{a} (\mathbf{b} ( \mathbf{c(\mathbf{x})})) 
        
       
     y=f(x)=a(b(c(x)))
其中:
- x ∈ R n \mathbf{x} \in \mathbb{R}^n x∈Rn 是输入向量
 - a ∈ R m \mathbf{a} \in \mathbb{R}^m a∈Rm
 - b ∈ R k \mathbf{b} \in \mathbb{R}^k b∈Rk
 - c ∈ R p \mathbf{c} \in \mathbb{R}^p c∈Rp
 
根据链式法则,得到导数:
  
      
       
        
         
          
          
            ∂ 
           
          
            y 
           
          
          
          
            ∂ 
           
          
            c 
           
          
         
        
          ∈ 
         
         
         
           R 
          
          
          
            1 
           
          
            × 
           
          
            p 
           
          
         
        
          , 
         
         
         
          
          
            ∂ 
           
          
            c 
           
          
          
          
            ∂ 
           
          
            b 
           
          
         
        
          ∈ 
         
         
         
           R 
          
          
          
            p 
           
          
            × 
           
          
            k 
           
          
         
        
          , 
         
         
         
          
          
            ∂ 
           
          
            b 
           
          
          
          
            ∂ 
           
          
            a 
           
          
         
        
          ∈ 
         
         
         
           R 
          
          
          
            k 
           
          
            × 
           
          
            m 
           
          
         
        
          , 
         
         
         
          
          
            ∂ 
           
          
            a 
           
          
          
          
            ∂ 
           
          
            x 
           
          
         
        
          ∈ 
         
         
         
           R 
          
          
          
            m 
           
          
            × 
           
          
            n 
           
          
         
        
       
         \frac{\partial y}{\partial \mathbf{c}} \in \mathbb{R}^{1 \times p}, \quad \frac{\partial \mathbf{c}}{\partial \mathbf{b}} \in \mathbb{R}^{p \times k}, \quad \frac{\partial \mathbf{b}}{\partial \mathbf{a}} \in \mathbb{R}^{k \times m}, \quad \frac{\partial \mathbf{a}}{\partial \mathbf{x}} \in \mathbb{R}^{m \times n} 
        
       
     ∂c∂y∈R1×p,∂b∂c∈Rp×k,∂a∂b∈Rk×m,∂x∂a∈Rm×n
则有:
  
      
       
        
         
         
           f 
          
         
           ′ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          = 
         
         
          
          
            ∂ 
           
          
            y 
           
          
          
          
            ∂ 
           
          
            c 
           
          
         
        
          ⋅ 
         
         
          
          
            ∂ 
           
          
            c 
           
          
          
          
            ∂ 
           
          
            b 
           
          
         
        
          ⋅ 
         
         
          
          
            ∂ 
           
          
            b 
           
          
          
          
            ∂ 
           
          
            a 
           
          
         
        
          ⋅ 
         
         
          
          
            ∂ 
           
          
            a 
           
          
          
          
            ∂ 
           
          
            x 
           
          
         
        
       
         f^{\prime}(\mathbf{x})=\frac{\partial y}{\partial \mathbf{c}} \cdot \frac{\partial \mathbf{c}}{\partial \mathbf{b}} \cdot \frac{\partial \mathbf{b}}{\partial \mathbf{a}} \cdot \frac{\partial \mathbf{a}}{\partial \mathbf{x}} 
        
       
     f′(x)=∂c∂y⋅∂b∂c⋅∂a∂b⋅∂x∂a
正向模式自动微分是从输入向输出计算导数的方法。在这一模式中,首先计算输入
  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x对各个中间变量的导数,然后利用链式法则逐步向输出传播。计算顺序是从右往左,用括号表示如下:
  
      
       
        
         
         
           f 
          
         
           ′ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          = 
         
         
          
          
            ∂ 
           
          
            y 
           
          
          
          
            ∂ 
           
          
            c 
           
          
         
        
          ⋅ 
         
         
         
           ( 
          
          
           
           
             ∂ 
            
           
             c 
            
           
           
           
             ∂ 
            
           
             b 
            
           
          
         
           ⋅ 
          
          
          
            ( 
           
           
            
            
              ∂ 
             
            
              b 
             
            
            
            
              ∂ 
             
            
              a 
             
            
           
          
            ⋅ 
           
           
            
            
              ∂ 
             
            
              a 
             
            
            
            
              ∂ 
             
            
              x 
             
            
           
          
            ) 
           
          
         
           ) 
          
         
        
       
         f^{\prime}(\mathbf{x})=\frac{\partial y}{\partial \mathbf{c}} \cdot\left(\frac{\partial \mathbf{c}}{\partial \mathbf{b}} \cdot\left(\frac{\partial \mathbf{b}}{\partial \mathbf{a}} \cdot \frac{\partial \mathbf{a}}{\partial \mathbf{x}}\right)\right) 
        
       
     f′(x)=∂c∂y⋅(∂b∂c⋅(∂a∂b⋅∂x∂a))
反向模式自动微分是从输出向输入传播导数的方法。在这一模式中,首先计算输出
  
     
      
       
       
         y 
        
       
      
        y 
       
      
    y对各个中间变量的导数,然后利用链式法则反向传播。计算顺序是从左往右,用括号表示如下:
  
      
       
        
         
         
           f 
          
         
           ′ 
          
         
        
          ( 
         
        
          x 
         
        
          ) 
         
        
          = 
         
         
         
           ( 
          
          
          
            ( 
           
           
           
             ( 
            
            
             
             
               ∂ 
              
             
               y 
              
             
             
             
               ∂ 
              
             
               c 
              
             
            
           
             ⋅ 
            
            
             
             
               ∂ 
              
             
               c 
              
             
             
             
               ∂ 
              
             
               b 
              
             
            
           
             ) 
            
           
          
            ⋅ 
           
           
            
            
              ∂ 
             
            
              b 
             
            
            
            
              ∂ 
             
            
              a 
             
            
           
          
            ) 
           
          
         
           ⋅ 
          
          
           
           
             ∂ 
            
           
             a 
            
           
           
           
             ∂ 
            
           
             x 
            
           
          
         
           ) 
          
         
        
       
         f^{\prime}(\mathbf{x})=\left(\left(\left(\frac{\partial y}{\partial \mathbf{c}} \cdot \frac{\partial \mathbf{c}}{\partial \mathbf{b}}\right) \cdot \frac{\partial \mathbf{b}}{\partial \mathbf{a}}\right) \cdot \frac{\partial \mathbf{a}}{\partial \mathbf{x}}\right) 
        
       
     f′(x)=(((∂c∂y⋅∂b∂c)⋅∂a∂b)⋅∂x∂a)
总之,正向模式是从输入开始,逐步计算输出的梯度;而反向模式是从输出开始,先计算输出对中间变量的梯度,再向输入传播。
这是对于 y y y为标量的情况,那么对于 y \mathbf{y} y为向量的情况下,正向和反向又会怎么进行呢?此时我们就需要引入“雅可比矩阵(Jacobian Matrix)”了。
雅克比矩阵以及JVP、VJP
雅可比矩阵
对于函数  
     
      
       
       
         y 
        
       
         = 
        
       
         f 
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \mathbf{y}=f(\mathbf{x}) 
       
      
    y=f(x) ,其中  
     
      
       
       
         f 
        
       
         : 
        
        
        
          R 
         
        
          n 
         
        
       
         → 
        
        
        
          R 
         
        
          m 
         
        
       
      
        f: \mathbb{R}^n \rightarrow \mathbb{R}^m 
       
      
    f:Rn→Rm ,那么  
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y 中关于  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x 的梯度可以表示为雅克比矩阵 :
  
      
       
        
         
         
           J 
          
         
           f 
          
         
        
          = 
         
         
         
           [ 
          
          
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
           
            
             
             
               ⋮ 
              
              
               
              
             
            
            
             
             
               ⋱ 
              
             
            
            
             
             
               ⋮ 
              
              
               
              
             
            
           
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
          
         
           ] 
          
         
        
       
         \boldsymbol{J}_f=\left[\begin{array}{ccc} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{array}\right] 
        
       
     Jf=⎣⎢⎡∂x1∂y1⋮∂x1∂ym⋯⋱⋯∂xn∂y1⋮∂xn∂ym⎦⎥⎤
JVP
雅可比向量积(Jacobian-Vector Product, JVP)是指雅可比矩阵与一个向量的乘积,用于描述在给定输入变化时,输出变化的灵敏度。在 JVP 中,我们从输入出发,计算其对输出的影响。形式上,对于函数  
     
      
       
       
         y 
        
       
         = 
        
       
         f 
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \mathbf{y}=f(\mathbf{x}) 
       
      
    y=f(x) 及输入向量  
     
      
       
       
         v 
        
       
         ∈ 
        
        
        
          R 
         
        
          n 
         
        
       
      
        \mathbf{v} \in \mathbb{R}^n 
       
      
    v∈Rn ,JVP 定义为:
  
      
       
        
        
          JVP 
         
        
           
         
        
          ( 
         
        
          v 
         
        
          ) 
         
        
          = 
         
         
         
           J 
          
         
           f 
          
         
        
          ⋅ 
         
        
          v 
         
        
       
         \operatorname{JVP}(\mathbf{v})=\boldsymbol{J}_f \cdot \mathbf{v} 
        
       
     JVP(v)=Jf⋅v
这表示在输入  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x 沿着向量  
     
      
       
       
         v 
        
       
      
        \mathbf{v} 
       
      
    v 变化时,输出  
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y 的变化。
 举个例子,输出 
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y关于  
     
      
       
        
        
          x 
         
        
          1 
         
        
       
      
        x_1 
       
      
    x1 的导数则可通过Jacobian矩阵乘以对应的one-hot向量 
     
      
       
       
         v 
        
       
      
        \mathbf{v} 
       
      
    v来得到,
  
      
       
        
         
         
           ∂ 
          
          
          
            x 
           
          
            1 
           
          
         
        
          y 
         
        
          = 
         
         
         
           J 
          
         
           f 
          
         
        
          ⋅ 
         
        
          v 
         
        
          = 
         
         
         
           [ 
          
          
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
           
           
            
             
             
               ⋮ 
              
              
               
              
             
            
           
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
           
          
         
           ] 
          
         
        
          = 
         
         
         
           [ 
          
          
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
           
            
             
             
               ⋮ 
              
              
               
              
             
            
            
             
             
               ⋱ 
              
             
            
            
             
             
               ⋮ 
              
              
               
              
             
            
           
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
          
         
           ] 
          
         
        
          ⋅ 
         
         
         
           [ 
          
          
           
            
             
             
               1 
              
             
            
           
           
            
             
             
               ⋮ 
              
              
               
              
             
            
           
           
            
             
             
               0 
              
             
            
           
          
         
           ] 
          
         
        
          , 
         
        
          v 
         
        
          ∈ 
         
         
         
           R 
          
         
           n 
          
         
        
       
         \partial_{x_1} \mathbf{y}=\boldsymbol{J}_f \cdot \mathbf{v}=\left[\begin{array}{c} \frac{\partial y_1}{\partial x_1} \\ \vdots \\ \frac{\partial y_m}{\partial x_1} \end{array}\right]=\left[\begin{array}{ccc} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{array}\right] \cdot\left[\begin{array}{c} 1 \\ \vdots \\ 0 \end{array}\right], \mathbf{v} \in \mathbb{R}^n 
        
       
     ∂x1y=Jf⋅v=⎣⎢⎡∂x1∂y1⋮∂x1∂ym⎦⎥⎤=⎣⎢⎡∂x1∂y1⋮∂x1∂ym⋯⋱⋯∂xn∂y1⋮∂xn∂ym⎦⎥⎤⋅⎣⎢⎡1⋮0⎦⎥⎤,v∈Rn
v \mathbf{v} v 本来可以是任意方向的向量,但是我们这里为了方便举例,就简化为 one-hot 向量,并使得结果仅为输出向量 y \mathbf{y} y 中第一个元素 y \mathbf{y} y 关于 x 1 x_1 x1 的梯度。
VJP
雅可比-向量积(Jacobian-Vector Product, VJP) 是指一个向量与雅可比矩阵的乘积,用于描述在给定输出变化时,输入变化的灵敏度。在 VJP 中,我们从输出的特定元素出发,计算其对输入的梯度。形式上,对于函数  
     
      
       
       
         y 
        
       
         = 
        
       
         f 
        
       
         ( 
        
       
         x 
        
       
         ) 
        
       
      
        \mathbf{y}=f(\mathbf{x}) 
       
      
    y=f(x) 及输出向量中的一个特定元素的 one-hot 向量  
     
      
       
       
         v 
        
       
         ∈ 
        
        
        
          R 
         
        
          m 
         
        
       
      
        \mathbf{v} \in \mathbb{R}^m 
       
      
    v∈Rm ,VJP 定义为:
  
      
       
        
        
          VJP 
         
        
           
         
        
          ( 
         
        
          v 
         
        
          ) 
         
        
          = 
         
         
         
           v 
          
         
           T 
          
         
        
          ⋅ 
         
         
         
           J 
          
         
           f 
          
         
        
       
         \operatorname{VJP}(\mathbf{v})=\mathbf{v}^T \cdot \boldsymbol{J}_f 
        
       
     VJP(v)=vT⋅Jf
这表示在输出  
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y 沿着向量  
     
      
       
       
         v 
        
       
      
        \mathbf{v} 
       
      
    v 变化时,输入  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x 的变化。
 举个例子,假设我们希望计算输出  
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y 中第一个元素  
     
      
       
        
        
          y 
         
        
          1 
         
        
       
      
        y_1 
       
      
    y1 对输入  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x 的导数,可以通过雅可比矩阵与对应的 one-hot 向量  
     
      
       
       
         v 
        
       
      
        \mathbf{v} 
       
      
    v 的乘积得到:
  
      
       
        
         
         
           ∂ 
          
          
          
            y 
           
          
            1 
           
          
         
        
          x 
         
        
          = 
         
         
         
           v 
          
         
           T 
          
         
        
          ⋅ 
         
         
         
           J 
          
         
           f 
          
         
        
          = 
         
        
          [ 
         
        
          1 
         
        
          , 
         
        
          0 
         
        
          , 
         
        
          … 
         
        
          , 
         
        
          0 
         
        
          ] 
         
        
          ⋅ 
         
         
         
           [ 
          
          
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  1 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
           
            
             
             
               ⋮ 
              
              
               
              
             
            
            
             
             
               ⋱ 
              
             
            
            
             
             
               ⋮ 
              
              
               
              
             
            
           
           
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  1 
                 
                
               
              
             
            
            
             
             
               ⋯ 
              
             
            
            
             
              
               
               
                 ∂ 
                
                
                
                  y 
                 
                
                  m 
                 
                
               
               
               
                 ∂ 
                
                
                
                  x 
                 
                
                  n 
                 
                
               
              
             
            
           
          
         
           ] 
          
         
        
          , 
         
        
          v 
         
        
          ∈ 
         
         
         
           R 
          
         
           m 
          
         
        
       
         \partial_{y_1} \mathbf{x}=\mathbf{v}^T \cdot \boldsymbol{J}_f=[1,0, \ldots, 0] \cdot\left[\begin{array}{ccc} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{array}\right], \mathbf{v} \in \mathbb{R}^m 
        
       
     ∂y1x=vT⋅Jf=[1,0,…,0]⋅⎣⎢⎡∂x1∂y1⋮∂x1∂ym⋯⋱⋯∂xn∂y1⋮∂xn∂ym⎦⎥⎤,v∈Rm
  
     
      
       
       
         v 
        
       
      
        \mathbf{v} 
       
      
    v 本来可以是任意方向的向量,但我们这里为了方便举例,就简化为 one-hot 向量,这样可以使得结果仅为输出向量  
     
      
       
       
         y 
        
       
      
        \mathbf{y} 
       
      
    y 中第一个元素  
     
      
       
        
        
          y 
         
        
          1 
         
        
       
      
        y_1 
       
      
    y1 关于所有输入  
     
      
       
       
         x 
        
       
      
        \mathbf{x} 
       
      
    x 的梯度。
总结
刚才介绍的JVP、VJP是不是听上去感觉和正向、方向模式似乎有些关系?
没错!
-  
正向模式对应于利用 JVP来实现输出向量(所有输出)对单一参数的求导。特别的,正向模式适用于输入维度较小的情况,因为它可以有效地逐步计算出每个输入对输出的影响,从而得到雅可比矩阵的每一列。
 -  
反向模式则对应于利用 VJP 来实现输出的某个分量对参数向量(所有参数)的求导。在这一模式下,我们从输出出发,计算输出变化相对于输入变化的灵敏度,使用雅可比矩阵的转置与输出变化的向量相乘。这使得反向模式特别适合于输出维度较小的情况,因为它允许我们在一次反向传播中计算出雅可比矩阵的每一行。
![在这里插入图片描述]()
 
为什么反向模式多用于深度学习?
在深度学习中,模型的输出通常是一个标量(如损失值)或低维向量,而输入则可能是高维(如图像、文本等)。反向模式能够高效地计算损失对所有参数的梯度,适合大规模模型训练。这种模式可以在一次前向传播后,利用一次反向传播获得所有参数的梯度,极大地提高了计算效率。因此,反向模式在深度学习中的广泛应用得益于其在处理高维输入时的优势。
posted on 2024-10-23 03:15 ACEEE-1222 阅读(209) 评论(0) 收藏 举报 来源
                    
                
                
            
        
浙公网安备 33010602011771号