Scalable Methods for 8-bit Training of Neural Networks
概
本文针对 Batch Norm 模块在低精度 (8-bit) 的情况下进行一个合适的改进.
Range Batch Normalization
-
对于一个 \(n \times d\) 的输入 \(x = (x^{(1)}, x^{(2)}, \ldots, x^{(d)})\), Batch Norm 模块 normalize 每个维度:
\[\tag{1} \hat{x}^d = \frac{ x^{(d)} - \mu^d }{ \sqrt{ \text{Var}[x^{(d)}] } }, \]其中 \(\mu^d\) 是 \(x^{(d)}\) 上的平均, \(\text{Var}[x^{(d)}] = \frac{1}{n} \| x^{(d)} - \mu^d \|_2^2\).
-
容易发现, \(\sqrt{\text{Var}[x^{(d)}]}\) 这一项涉及平方和, 在低精度的情况下很容易导致数值不稳定. 因此, 本文希望提出一个 Range BN 模块来实现低精度下的一个鲁棒模拟:
\[\tag{2} \hat{x}^d = \frac{ x^{(d)} - \mu^d }{ C(n) \cdot \text{range}(x^{(d)} - \mu^d) }, \]其中 \(\text{range}(x) = \max(x) - \min(x)\), \(C(n) = 1 / \sqrt{2 \ln (n)}\).
-
我们来解释为什么这么做:
- 首先, 如果 \(x\) 服从高斯分布, 我们有\[0.23 \sigma \cdot \sqrt{\ln(n)} \le \mathbb{E}[\max(x^{(d)} - \mu^d)] \le \sqrt{2} \sigma \cdot \sqrt{\ln(n)}. \]
- 根据, 如果 \(x\) 关于 \(0\) 对称, 我们有 \(\mathbb{E}[\max(x)] = -\mathbb{E}[\min (x)]\), 我们有\[0.23 \sigma \cdot \sqrt{\ln(n)} \le -\mathbb{E}[\min(x^{(d)} - \mu^d)] \le \sqrt{2} \sigma \cdot \sqrt{\ln(n)}. \]
- 上面二式相加可得:\[0.325\sigma \le C(n) \cdot \text{range}(x^{(d)}- \mu^d) \le 2 \sigma. \]
- 首先, 如果 \(x\) 服从高斯分布, 我们有
-
因此, (2) 可以作为 (1) 的有效的稳定的模拟.

浙公网安备 33010602011771号