显存计算指南
1. 基础知识
-
存储单位转换:
- 1 GB = 1024 MB
- 1 MB = 1024 KB
- 1 KB = 1024 Byte
- 1 Byte = 8 Bit
-
数据精度:
- FP32: 32 Bits = 4 Bytes
- FP16: 16 Bits = 2 Bytes
2. 输入输出计算
以 Llama 13B 为例:
-
参数:
b: Batch Size = 1s: Sequence Length = 1024h: Hidden Size = 5120
-
计算:
Embedding 后的输入(使用FP16精度存储):\[ {b} \cdot {s} \cdot {h} \cdot 2 \,{Bytes} \,/\, 1024 \, ({转 KB})\,/\, 1024 \, ({转 MB}) \]- FP16的精度下占用内存: 10 MB
输出与输入维度一致:
- 输入 + 输出 = 20 MB
3. 模型参数计算
-
1B 与 1GB 的定义:(单位转换)
- 1B (Billion) 通常指十亿,计算为 (1000^3 = 10^9)。
- 1GB (Gibibyte) 是基于二进制的计算单位,$$1GB = 1024^3 = 1,073,741,824 \ \text{bytes}$$
-
FP16 的内存需求:
- FP16 (半精度浮点数) 每个参数占用 2 个字节。
- 如果有 13 亿(1.3B)的参数,总内存需求为:\[13 \times 10^9 \, \text{(参数数目)} \times 2 \, \text{(每个参数的字节数)} = 26 \, {GB} \]
- 此处的 GB 是以\(1024^3 \, \text{bytes}\)为基准。
- 26GB 是模型参数需要的总内存空间,用于存储 13 亿个 FP16 参数。这个计算基于 GB=1024³ bytes 的定义,而不是十进制单位的 \(1GB \text{(}1000^3 \, \text{bytes)}\)。
4. 优化器显存计算
Adam 优化器 (FP32):
- 梯度指数平滑: \(13 \cdot 4 = 52 \, {GB}\)
- 梯度平方指数平滑: \(13 \cdot 4 = 52 \, {GB}\)
- 模型参数存储: \(13 \cdot 4 = 52 \, {GB}\)
- 总计: \(52 + 52 + 52 = 156 \, {GB}\)
注: 维护一个 FP32 备份是为了更精确的权重更新。
5. 激活值显存计算
5.1 激活值的计算流程
激活值是在前向传播过程中产生的中间结果,通常用于后续层的计算。
例如对于一个简单的神经网络:
前向传播(Forward Propagation)
- 输入:$x $
- 权重:$ w_1, w_2 $
- 激活函数:\(Sigmoid\)
- 损失函数:\(均方误差(MSE)\)
步骤如下:
- 计算第一层的加权和:\[z_1 = x \cdot w_1 \]
- 通过激活函数得到激活值:\[a_1 = \sigma(z_1) = \frac{1}{1 + e^{-z_1}} \]
- 计算第二层的加权和:\[z_2 = a_1 \cdot w_2 \]
- 计算损失:\[\text{loss} = (y - z_2)^2 \]
反向传播(Backward Propagation)
- 对 $ w_2 $ 的梯度:\[\frac{\partial \text{loss}}{\partial w_2} = \frac{\partial \text{loss}}{\partial z_2} \cdot \frac{\partial z_2}{\partial w_2} = 2(z_2 - y) \cdot a_1 \]
- 对 $ w_1 $ 的梯度:\[\frac{\partial \text{loss}}{\partial w_1} = \frac{\partial \text{loss}}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial w_1} \]其中:\[\frac{\partial a_1}{\partial z_1} = \sigma(z_1) \cdot (1 - \sigma(z_1)) \]带入后:\[\frac{\partial \text{loss}}{\partial w_1} = 2(z_2 - y) \cdot w_2 \cdot \sigma(z_1) \cdot (1 - \sigma(z_1)) \cdot x \]
可以看出,为了更新权重,我们需要计算损失函数对各个权重的梯度。此时就需要拿出之前存储过的\(z_1,a_1和z_2\)了
5.2 显存需求公式解析
对于Llama,结构稍微复杂一些,可以参考 https://zhuanlan.zhihu.com/p/673916177 这篇文章。
模型(Llama 13B),参数为:
- 序列长度:( s = 1024 )
- Batch size:( b = 1 )
- 隐藏层大小:( h = 5120 )
- Attention 头数:( a = 40 )
- 层数:( L = 40 )
显存需求公式(FP16 精度)为:
\[\text{显存需求} = s \cdot b \cdot h \cdot \left(34 + 5 \cdot a \cdot \frac{s}{h}\right) \cdot \frac{L}{1024^3} \, \text{GB}
\]
-
激活值的计算:是前向传播中通过激活函数得到的中间结果,和Batch_size, Seq_len相关,决定了后续层的输入。
-
显存需求公式:从激活值存储出发,综合考虑 Attention、隐藏层大小和层数的影响,估算显存使用量。
计算结果:
- 14.5 GB (Batch Size = 1, Sequence Length = 1024)
- 真实的Llama的batch size = 400,0000
批量计算: \(14.5 \, {GB} \cdot 400,0000 = 5800,0000 \, {GB}\)
6. 梯度值计算
- 每个参数都有一个梯度值:
- FP16: \(13 \, {B} \cdot 2 \, {Bytes} = 26 \, {GB}\)
7. 总显存计算
- 输入输出: \(20 \, {MB}\)
- 模型参数:\(26 \, {GB}\)
- 优化器: \(156 \, {GB}\)
- 激活值: \(14.5 \, {GB}\)
- 梯度值: \(26 \, {GB}\)

浙公网安备 33010602011771号