一只菜鸟学机器学习的日记:梯度问题与Xavier初始化、He Kaiming初始化
本文以作者阅读《Dive into Deep Learning》为线索,融合串联了自身理解感悟、原始论文、优秀文章等。如有无意侵权,请联系本人删除。
梯度问题:
我们可以认为,
那么不难得到:
也因此,梯度是\(L-l\)个矩阵与一个梯度向量的乘积,可能会数值下溢。
梯度消失:
- 如果使用Sigmoid函数,变化还是太集中了。因此当输入很大或很小时,梯度消失。为此我们最好用ReLU函数替代之。
- 如果每一次的 梯度都减小一点,那么多层传播后梯度值会非常小。
- 如果初始值太小,向前传播过程中输入信号迅速衰减,导致激活函数的输入值非常小。
梯度爆炸:
- 特指反向传播过程中,梯度值随着层级增加而不断变大,乃至指数型增加。
- 很可能因为 \(weight\) 的初始值太大,层数过多, \(lr\) 太高(可能少加了一个小数点后的 \(0\),惨痛的教训)
参数化的对称性:
若每层的参数均初始化为 \(c\) ,那么迭代多少次参数仍然为相同的值。
参数初始化
默认初始化:使用正态分布
Xavier初始化:
目标:保持各层激活值方差稳定,确保前向传播的信号强度和反向传播的梯度强度在初始化时不衰减也不爆炸。
论文提出时只讨论了 \(sigmoid\),\(tanh\),\(softsign\) 初始函数,并未涉及 \(ReLU\) 函数,为 \(Kaiming\) 初始化提出埋下伏笔。
这里的3个函数都有饱和区,也就是梯度消失的那段区域,太大或太小时函数导数趋于 \(0\) 。
这个理论的基本原则就是:保证前向传播中激活值方差尽量不变,后向传播的梯度方差尽可能不变。 也就是说初始化阶段的激活值和梯度的期望均为 \(0\)。这就排除了 \(Sigmoid\),因此 \(Xavier\) 不适用于 \(Sigmoid\) 。
用数学语言表述,就是要激活函数 \(f(x)\) 满足:
再换句话,由观察,任意层的输入信号方差应等于其输出信号方差:
观察第 \(l\) 层的线性变换:
那么有
因此前向传播要求:\(Var(\mathcal w) \approx \frac{1}{n_{in}}\)
反向传播要求:\(Var(\mathcal w) \approx \frac{1}{n_{out}}\)
为此,Xavier 用调和平均数平衡两者:
这样,标准差就出来了:
因此初始权值应符合的正态分布:
然鹅,Xavier初始化提出的时间有点早,ReLU激活函数还没有得到广泛应用。
对于ReLY函数,Xavier初始化力不从心:
- ReLU的函数输出非对称:\(y \in [0,+∞)\)
- 负的输入反向输出时梯度为 \(0\)
- 会将 \(50\%\) 的神经元输出清零,从而
- 前向传播:\(Var(a) \approx \frac{1}{2}Var(y)\)
- 反向传播:梯度方差同样减半
面对这些问题,He初始化(Kaiming初始化)被提了出来。
对于向前传播:
对\(y_i\)加入ReLU函数得到\(a_i\):
我们为了保证每一层的输入输出方差尽可能一致,即 \(Var{(a_i)}=Var{(x_j)}\)
我们化简后不难得到:
以此类推,可以得到反向传播时,
一般情况,我们使用前向传播优先,即
我们为什么不能类比Xavier做调和平均呢?
因为ReLU的单向激活特性使得前向传播和反向传播的方差传播规律不同:
- 对前向传播,分布截断,方差复杂衰减
- 对反向传播,简单伯努利掩码,方差衰减0.5倍
- 二者传播规律、衰减因子不同,故不可使用统一的调和平均
pytorch实现:
layer = nn.Linear(64, 128)
init.kaiming_normal_(layer.weight, a=0, mode='fan_in', nonlinearity='relu')
# a:负斜率(Leaky ReLU 的情况,默认为0)
# Leaky ReLU : 负x轴设置为 ax ,而不是 0 ,通常 a = 0.01
那么,还有什么方法能够解决梯度问题吗?
拓展一下吧
- 梯度剪切:设置一个梯度阈值,强制限定梯度大小范围。
- 改用ReLU:正数部分激活函数的导数为1,很大程度缓解梯度消失问题,但是可能会引起神经元死亡问题。
- 改用LeakReLU、eLU:缓解了ReLU在负数区间的神经元死亡问题,同时保留优点。
- Batch Normalization:标准化每层的输入,使其均值为0,方差为1
- 残差网络:假设我们学习映射 \(X\to Y\),与其学习 \(Y=H(X)\),我们学习残差 \(F(X) = H(X) - X\),只要让 \(F(X)=0\) 即可,显然这样对计算机计算更加友好,因为计算机只需要让\(W\to 0,b\to 0\)。(本段存疑)
那么对于残差网络,我们令输入为 \(h_{l-1}\)有:

浙公网安备 33010602011771号