Scalable Methods for 8-bit Training of Neural Networks

Banner R., Hubara I., Hoffer E. and Soudry D. Scalable methods for 8-bit training of neural networks. NeurIPS, 2018.

本文针对 Batch Norm 模块在低精度 (8-bit) 的情况下进行一个合适的改进.

Range Batch Normalization

  • 对于一个 \(n \times d\) 的输入 \(x = (x^{(1)}, x^{(2)}, \ldots, x^{(d)})\), Batch Norm 模块 normalize 每个维度:

    \[\tag{1} \hat{x}^d = \frac{ x^{(d)} - \mu^d }{ \sqrt{ \text{Var}[x^{(d)}] } }, \]

    其中 \(\mu^d\)\(x^{(d)}\) 上的平均, \(\text{Var}[x^{(d)}] = \frac{1}{n} \| x^{(d)} - \mu^d \|_2^2\).

  • 容易发现, \(\sqrt{\text{Var}[x^{(d)}]}\) 这一项涉及平方和, 在低精度的情况下很容易导致数值不稳定. 因此, 本文希望提出一个 Range BN 模块来实现低精度下的一个鲁棒模拟:

    \[\tag{2} \hat{x}^d = \frac{ x^{(d)} - \mu^d }{ C(n) \cdot \text{range}(x^{(d)} - \mu^d) }, \]

    其中 \(\text{range}(x) = \max(x) - \min(x)\), \(C(n) = 1 / \sqrt{2 \ln (n)}\).

  • 我们来解释为什么这么做:

    1. 首先, 如果 \(x\) 服从高斯分布, 我们有

      \[0.23 \sigma \cdot \sqrt{\ln(n)} \le \mathbb{E}[\max(x^{(d)} - \mu^d)] \le \sqrt{2} \sigma \cdot \sqrt{\ln(n)}. \]

    2. 根据, 如果 \(x\) 关于 \(0\) 对称, 我们有 \(\mathbb{E}[\max(x)] = -\mathbb{E}[\min (x)]\), 我们有

      \[0.23 \sigma \cdot \sqrt{\ln(n)} \le -\mathbb{E}[\min(x^{(d)} - \mu^d)] \le \sqrt{2} \sigma \cdot \sqrt{\ln(n)}. \]

    3. 上面二式相加可得:

      \[0.325\sigma \le C(n) \cdot \text{range}(x^{(d)}- \mu^d) \le 2 \sigma. \]

  • 因此, (2) 可以作为 (1) 的有效的稳定的模拟.

代码

[official-code]

posted @ 2025-01-06 21:20  馒头and花卷  阅读(46)  评论(0)    收藏  举报