显存计算指南

1. 基础知识

  • 存储单位转换:

    • 1 GB = 1024 MB
    • 1 MB = 1024 KB
    • 1 KB = 1024 Byte
    • 1 Byte = 8 Bit
  • 数据精度:

    • FP32: 32 Bits = 4 Bytes
    • FP16: 16 Bits = 2 Bytes

2. 输入输出计算

Llama 13B 为例:

  • 参数:

    • b: Batch Size = 1
    • s: Sequence Length = 1024
    • h: Hidden Size = 5120
  • 计算:
    Embedding 后的输入(使用FP16精度存储):

    \[ {b} \cdot {s} \cdot {h} \cdot 2 \,{Bytes} \,/\, 1024 \, ({转 KB})\,/\, 1024 \, ({转 MB}) \]

    • FP16的精度下占用内存: 10 MB

    输出与输入维度一致:

    • 输入 + 输出 = 20 MB

3. 模型参数计算

  1. 1B 与 1GB 的定义:(单位转换)

    • 1B (Billion) 通常指十亿,计算为 (1000^3 = 10^9)。
    • 1GB (Gibibyte) 是基于二进制的计算单位,$$1GB = 1024^3 = 1,073,741,824 \ \text{bytes}$$
  2. FP16 的内存需求

    • FP16 (半精度浮点数) 每个参数占用 2 个字节。
    • 如果有 13 亿(1.3B)的参数,总内存需求为:

      \[13 \times 10^9 \, \text{(参数数目)} \times 2 \, \text{(每个参数的字节数)} = 26 \, {GB} \]

    • 此处的 GB 是以\(1024^3 \, \text{bytes}\)为基准。
  • 26GB 是模型参数需要的总内存空间,用于存储 13 亿个 FP16 参数。这个计算基于 GB=1024³ bytes 的定义,而不是十进制单位的 \(1GB \text{(}1000^3 \, \text{bytes)}\)

4. 优化器显存计算

Adam 优化器 (FP32):

  • 梯度指数平滑: \(13 \cdot 4 = 52 \, {GB}\)
  • 梯度平方指数平滑: \(13 \cdot 4 = 52 \, {GB}\)
  • 模型参数存储: \(13 \cdot 4 = 52 \, {GB}\)
  • 总计: \(52 + 52 + 52 = 156 \, {GB}\)

: 维护一个 FP32 备份是为了更精确的权重更新。


5. 激活值显存计算

5.1 激活值的计算流程

激活值是在前向传播过程中产生的中间结果,通常用于后续层的计算。

例如对于一个简单的神经网络:

前向传播(Forward Propagation)

  • 输入:$x $
  • 权重:$ w_1, w_2 $
  • 激活函数:\(Sigmoid\)
  • 损失函数:\(均方误差(MSE)\)

步骤如下:

  1. 计算第一层的加权和:

    \[z_1 = x \cdot w_1 \]

  2. 通过激活函数得到激活值:

    \[a_1 = \sigma(z_1) = \frac{1}{1 + e^{-z_1}} \]

  3. 计算第二层的加权和:

    \[z_2 = a_1 \cdot w_2 \]

  4. 计算损失:

    \[\text{loss} = (y - z_2)^2 \]

反向传播(Backward Propagation)

  • 对 $ w_2 $ 的梯度:

    \[\frac{\partial \text{loss}}{\partial w_2} = \frac{\partial \text{loss}}{\partial z_2} \cdot \frac{\partial z_2}{\partial w_2} = 2(z_2 - y) \cdot a_1 \]

  • 对 $ w_1 $ 的梯度:

    \[\frac{\partial \text{loss}}{\partial w_1} = \frac{\partial \text{loss}}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial w_1} \]

    其中:

    \[\frac{\partial a_1}{\partial z_1} = \sigma(z_1) \cdot (1 - \sigma(z_1)) \]

    带入后:

    \[\frac{\partial \text{loss}}{\partial w_1} = 2(z_2 - y) \cdot w_2 \cdot \sigma(z_1) \cdot (1 - \sigma(z_1)) \cdot x \]

可以看出,为了更新权重,我们需要计算损失函数对各个权重的梯度。此时就需要拿出之前存储过的\(z_1,a_1和z_2\)

5.2 显存需求公式解析

对于Llama,结构稍微复杂一些,可以参考 https://zhuanlan.zhihu.com/p/673916177 这篇文章。
模型(Llama 13B),参数为:

  • 序列长度:( s = 1024 )
  • Batch size:( b = 1 )
  • 隐藏层大小:( h = 5120 )
  • Attention 头数:( a = 40 )
  • 层数:( L = 40 )

显存需求公式(FP16 精度)为:

\[\text{显存需求} = s \cdot b \cdot h \cdot \left(34 + 5 \cdot a \cdot \frac{s}{h}\right) \cdot \frac{L}{1024^3} \, \text{GB} \]

  • 激活值的计算:是前向传播中通过激活函数得到的中间结果,和Batch_size, Seq_len相关,决定了后续层的输入。

  • 显存需求公式:从激活值存储出发,综合考虑 Attention、隐藏层大小和层数的影响,估算显存使用量。

    计算结果:

    • 14.5 GB (Batch Size = 1, Sequence Length = 1024)
    • 真实的Llama的batch size = 400,0000
      批量计算: \(14.5 \, {GB} \cdot 400,0000 = 5800,0000 \, {GB}\)

6. 梯度值计算

  • 每个参数都有一个梯度值:
    • FP16: \(13 \, {B} \cdot 2 \, {Bytes} = 26 \, {GB}\)

7. 总显存计算

  • 输入输出: \(20 \, {MB}\)
  • 模型参数:\(26 \, {GB}\)
  • 优化器: \(156 \, {GB}\)
  • 激活值: \(14.5 \, {GB}\)
  • 梯度值: \(26 \, {GB}\)

合计: \(222.5 {GB}\)

posted @ 2024-12-07 19:18  AAA建材王师傅  阅读(195)  评论(0)    收藏  举报