Attention 显存计算 & 推理训练复杂度

在深度学习中,无论是进行模型推理(Inference)还是训练(更新参数),都需要占用大量GPU显存(VRAM)。然而,这两者在显存需求和构成上存在显著差异。总体而言,模型训练所需的显存远大于推理

下面将详细解析在两个阶段中,显存分别被哪些部分占用。

一、 模型推理(Inference)时的显存占用

模型推理是指使用已经训练好的模型进行预测。此阶段的显存占用相对简单,主要包括以下几个部分:

  1. 模型权重(Model Weights): 这是显存占用的主要部分。它指的是模型所有可学习参数(如卷积核的权重、全连接层的权重、偏置项等)所占用的空间。例如,一个以16位浮点数(FP16)存储的70亿参数(7B)大模型,仅模型权重就需要 70亿 * 2字节/参数 ≈ 14 GB 的显存。

  2. 激活值(Activations): 在前向传播过程中,每一层的计算输出被称为激活值。这些激活值需要被存储在显存中,以作为下一层的输入。在推理阶段,由于通常一次只处理一个或一小批(small batch)输入,并且计算过的层可以释放部分不再需要的激活值,因此这部分显存占用相对可控。其大小与批处理大小(Batch Size)和序列长度(Sequence Length)密切相关。

  3. 计算缓冲区(Workspace/Buffers): 深度学习框架(如TensorFlow, PyTorch)和底层加速库(如cuDNN)在执行特定计算(如卷积、矩阵乘法)时,会预先分配一些临时的显存空间作为缓冲区,以优化计算性能。这部分显存占用通常是动态的。

  4. 框架和CUDA上下文开销(Framework and CUDA Context Overhead): 初始化GPU和深度学习框架本身会占用一部分固定的显存,用于存储CUDA上下文、内核代码以及框架的运行时组件。这部分开销通常在几百MB到1GB以上。

推理时显存占用小结:
总显存 ≈ 模型权重 + 激活值 + 计算缓冲区 + 框架开销
其中,模型权重是绝对的大头

二、 模型训练(更新参数)时的显存占用

模型训练不仅包含推理时的所有步骤(即前向传播),还增加了反向传播和参数更新的过程,因此需要存储更多的中间信息。其显存占用主要包括:

  1. 模型权重(Model Weights): 与推理时相同,这是基础的显存开销。

  2. 激活值(Activations): 与推理时类似,但更为关键。在训练中,前向传播过程中计算出的所有激活值必须被完整保留,因为它们在后续的反向传播中需要用来计算梯度。因此,这部分的显存占用比推理时要大得多,并且直接受批处理大小(Batch Size)和模型深度的影响。

  3. 梯度(Gradients): 在反向传播过程中,会为模型中的每一个可学习参数计算一个梯度。这些梯度矩阵的尺寸与参数矩阵完全相同,因此它们会占用与模型权重几乎同样大小的显存空间。例如,一个14GB的FP16模型,其梯度也会占用约14GB的显存。

  4. 优化器状态(Optimizer States): 这是训练时显存占用的一个主要增长点。现代优化器(如Adam、AdamW)为了实现更高效的收敛,会为每个参数维护额外的状态信息。

    • Adam/AdamW优化器:会为每个参数存储两个状态量:一阶动量(momentum)和二阶动量(variance)。如果以32位浮点数(FP32)存储,这两个状态将占用 参数量 * 4字节 * 2 = 8倍参数量 的字节数。即使在混合精度训练中对优化器状态也使用FP32,它所占用的显存也是模型权重的2倍
  5. 计算缓冲区和框架开销: 与推理时相同,但由于训练过程更复杂,这部分开销可能会略有增加。

训练时显存占用小结:
总显存 ≈ 模型权重 + 激活值 + 梯度 + 优化器状态 + 缓冲区与框架开销

一个广为流传的粗略估算方法是,对于使用Adam优化器的全参数微调(Full Fine-tuning),显存占用大约是:

  • 模型权重: 1倍参数大小
  • 梯度: 1倍参数大小
  • 优化器状态: 2倍参数大小(Adam)
  • 激活值和其他: 变化较大,但通常也相当可观

因此,不考虑激活值的情况下,训练所需的显存至少是模型参数大小的4倍。这解释了为什么一个可以在24GB显卡上进行推理的模型,可能需要80GB甚至更多的显存才能进行全参数微调。

总结与对比

显存组成部分 模型推理(Inference) 模型训练(Parameter Updating)
模型权重 ✔️(主要部分) ✔️
激活值 ✔️(通常较小,与batch size相关) ✔️(必须保留所有,占用较大)
梯度 ✔️(与模型权重大小相当)
优化器状态 ✔️(通常是模型权重的2倍或更多)
框架/缓冲区 ✔️ ✔️

为了降低训练时的显存压力,业界发展出了多种优化技术,如:

  • 混合精度训练(Mixed Precision): 使用FP16/BF16存储权重、激活和梯度,降低显存占用。
  • 梯度累积(Gradient Accumulation): 通过多次前向/反向传播累积梯度,然后进行一次参数更新,从而在不减少总批处理量的情况下使用更小的单次批处理大小。
  • 梯度检查点(Gradient Checkpointing/Activation Recomputation): 不存储所有中间激活值,而是在反向传播时重新计算它们,用计算时间换取显存空间。
  • 模型并行与分布式训练(ZeRO等): 将模型的权重、梯度和优化器状态分散到多张GPU上,共同承担显存压力。
  • LoRA等参数高效微调(PEFT)技术: 只训练模型中一小部分新增的参数,从而极大地减少了梯度和优化器状态所需的显存。

深入解析注意力机制的计算复杂度

在深度学习领域,尤其是自然语言处理任务中,注意力(Attention)机制已成为一项核心技术,其成功推动了Transformer等模型的突破性发展。然而,强大的性能背后也存在着不容忽视的计算开销。理解注意力机制的计算复杂度对于模型优化和在资源受限环境下的部署至关重要。

其核心计算复杂度主要受两个因素影响:序列长度(n)和表示维度(d)。

标准点积注意力(Standard Dot-Product Attention)

对于一个输入序列,其包含n个token,每个token被表示为一个d维的向量。标准点积注意力的计算过程如下:

  1. 生成Query (Q), Key (K), Value (V) 矩阵: 通过将输入序列的嵌入矩阵(维度为 n x d)分别乘以三个权重矩阵(维度均为 d x d),得到Q、K、V三个矩阵,它们的维度仍然是 n x d。这一步的计算复杂度为 \(O(n \cdot d^2)\)

  2. 计算注意力分数: 核心步骤是计算Q和K的点积,即 \(Q \cdot K^T\)。这是一个 (n x d) 矩阵与一个 (d x n) 矩阵的乘法,结果是一个 (n x n) 的注意力分数矩阵。该步骤的计算复杂度为 \(O(n^2 \cdot d)\)

  3. 缩放与Softmax: 将注意力分数矩阵除以一个缩放因子(通常是 \(\sqrt{d_k}\),其中 \(d_k\) 是K向量的维度),然后对每一行应用Softmax函数,得到归一化的注意力权重。此步骤的复杂度为 \(O(n^2)\)

  4. 加权求和: 将得到的注意力权重矩阵与V矩阵相乘,得到最终的输出。这是一个 (n x n) 矩阵与 (n x d) 矩阵的乘法,复杂度为 \(O(n^2 \cdot d)\)

综合来看,标准点积注意力的计算复杂度为 \(O(n^2 \cdot d + n \cdot d^2)\)。在实际应用中,序列长度n通常远大于表示维度d,因此复杂度通常被简化为 \(O(n^2 \cdot d)\),即计算复杂度与序列长度的平方成正比。这使得处理长序列成为一个巨大的挑战。

多头注意力(Multi-Head Attention)

多头注意力机制并行地执行多次点积注意力计算。具体来说,它将d维的Q、K、V向量线性投射到h个较低的维度 \(d_k = d/h\) 上,分别计算注意力,然后将h个头的输出拼接并再次进行线性变换。

虽然听起来更复杂,但多头注意力的计算复杂度与单头注意力基本相同。这是因为在多头机制下,每个头的计算维度降低了。对于每个头,计算 \(Q \cdot K^T\) 的复杂度为 \(O(n^2 \cdot d_k)\)。由于有h个头,总的复杂度为 \(h \cdot O(n^2 \cdot d_k) = O(n^2 \cdot h \cdot d_k) = O(n^2 \cdot d)\)

因此,多头注意力的计算复杂度仍然是 \(O(n^2 \cdot d)\)。虽然总的计算量相近,但多头的设计允许模型在不同的表示子空间中共同学习来自不同位置的信息,从而提升了模型的表达能力。

交叉注意力(Cross-Attention)

交叉注意力与自注意力(Self-Attention)不同,它的Query来自于一个序列,而Key和Value来自于另一个序列。假设第一个序列的长度为m,第二个序列的长度为n,表示维度均为d。其计算复杂度的主要瓶颈在于计算Query矩阵(m x d)和Key转置矩阵(d x n)的乘积,得到一个 (m x n) 的注意力矩阵。因此,交叉注意力的计算复杂度为 \(O(m \cdot n \cdot d)\)

稀疏注意力(Sparse Attention)

为了缓解标准注意力机制中与序列长度平方成正比的计算瓶颈,研究人员提出了多种稀疏注意力机制。其核心思想是,对于每个token,不必计算其与序列中所有其他token的注意力分数,而是只关注一个有限的子集。

常见的稀疏模式包括:

  • 滑动窗口(Sliding Window): 每个token只关注其邻近的固定大小窗口内的token。
  • 扩张滑动窗口(Dilated Sliding Window): 类似于空洞卷积,跳过一些位置来扩大感受野。
  • 全局注意力(Global Attention): 预先选择一些“全局”token,所有其他token都需要关注这些全局token。
  • 随机注意力(Random Attention): 随机选择一些token进行关注。

通过这些稀疏化处理,注意力矩阵不再是密集的 \(n \times n\) 矩阵,而是只计算其中的一部分元素。这使得计算复杂度可以显著降低,例如降至 \(O(n \log n)\)\(O(n \sqrt{n})\),甚至在一些设计中可以达到线性复杂度 \(O(n)\)。这使得处理数千甚至数万长度的序列成为可能,极大地扩展了基于Transformer架构的应用范围。

综上所述,注意力机制的计算复杂度是其在实际应用中需要仔细考量的重要因素。虽然标准的自注意力机制功能强大,但其二次方复杂度限制了其处理长序列的能力,而各种稀疏注意力机制则为解决这一挑战提供了有效的途径。

矩阵乘法复杂度

当然,我们来详细解析一下矩阵相乘的计算复杂度。这是一个在计算机科学、工程和机器学习领域都至关重要的基础概念。

矩阵相乘的复杂度取决于所使用的具体算法。我们通常从最基础的定义法(或称朴素算法)开始,然后讨论更高级的理论算法。

1. 标准(朴素)算法的复杂度

这是我们在教科书中最先学到的,也是最直观的实现方法。

假设我们有两个矩阵:

  • 矩阵 A 的维度是 m × n
  • 矩阵 B 的维度是 n × p

要使两者能够相乘,第一个矩阵的列数(n)必须等于第二个矩阵的行数(n)。相乘得到的结果矩阵 C 的维度将是 m × p

计算过程分析:

为了计算结果矩阵 C 中的任意一个元素 \(C\_{ij}\)(第 i 行,第 j 列的元素),我们需要将矩阵 A 的第 i 行与矩阵 B 的第 j 列进行点积运算。

  • A 的第 i 行是一个有 n 个元素的行向量。
  • B 的第 j 列是一个有 n 个元素的列向量。

计算这个点积需要 n 次乘法和 n-1 次加法。我们可以将操作次数近似为 n

因为结果矩阵 C 总共有 m × p 个元素,而每个元素都需要 n 次乘法操作,所以总的计算复杂度是:

\[\text{总操作数} = m \times p \times n \]

使用大O表示法(Big O notation),我们说标准矩阵乘法的计算复杂度为:

\[O(m \cdot n \cdot p) \]

特例:两个方阵相乘

在理论分析和很多应用中,一个常见的特例是两个 n × n 的方阵相乘。在这种情况下,m = n = p

复杂度就变成了:

\[O(n \cdot n \cdot n) = O(n^3) \]

这意味着,如果将方阵的边长增加一倍,计算量将增加到原来的8倍 (\(2^3=8\))。这就是为什么对于大矩阵,朴素算法的计算成本会急剧上升。

代码示例(伪代码):
下面三层嵌套循环的结构清晰地展示了 \(O(n^3)\) 的复杂度。

# A是 n x n 矩阵, B是 n x n 矩阵
C = initialize_matrix(n, n)

for i in range(n):  # 遍历结果矩阵的行
    for j in range(n):  # 遍历结果矩阵的列
        sum = 0
        for k in range(n):  # 计算点积
            sum += A[i][k] * B[k][j]
        C[i][j] = sum

2. 更快的理论算法(Sub-cubic Algorithms)

\(O(n^3)\) 并非理论上的最优解。研究人员已经发现了一些复杂度低于立方级的算法,尽管它们在实际应用中不一定常用。

斯特拉森算法(Strassen's Algorithm)

这是第一个被发现的“快速”矩阵乘法算法,由Volker Strassen于1969年提出。它采用分治策略,将大矩阵递归地分解为小矩阵进行计算。

  • 复杂度: \(O(n^{\log_2 7}) \approx O(n^{2.807})\)

虽然其渐进复杂度低于 \(O(n^3)\),但由于算法本身更复杂,递归开销较大,且对数值稳定性有一定影响,所以它通常只在矩阵非常大(例如 n > 128 或更大,具体阈值取决于实现和硬件)时才比高度优化的标准算法更快。

Coppersmith-Winograd算法及其后继

更高级的算法,如Coppersmith-Winograd算法,在理论上进一步降低了复杂度。目前理论上最快的算法复杂度约为 \(O(n^{2.3728596})\)。然而,这些算法的常数因子极大,实现极其复杂,导致它们几乎没有任何实际应用价值,仅存在于理论计算机科学的讨论中。

最近,DeepMind使用人工智能(AlphaTensor)发现了新的、更高效的矩阵乘法算法,在特定大小的矩阵上超越了人类设计的算法,但这些仍属于前沿研究领域。

3. 实践中的考量:硬件与库

在实际应用中,我们几乎从不手动编写三层循环来进行矩阵乘法。我们会依赖高度优化的数学库。

  • 并行计算: 现代CPU和特别是GPU(图形处理器)拥有大量的并行处理单元。它们可以将矩阵乘法分解成成千上万个可以同时执行的小任务。虽然这不会改变总的计算操作数(即理论复杂度),但它能将计算时间缩短几个数量级。
  • 优化库 (BLAS, cuBLAS): 像BLAS (Basic Linear Algebra Subprograms)、Intel MKL、NVIDIA cuBLAS这样的底层库,其矩阵乘法实现经过了极致的优化。它们利用了硬件的SIMD指令(单指令多数据流)、缓存层次结构和并行计算能力,性能远非朴素的三重循环可比。

总结

算法/方法 复杂度 (n x n 方阵) 实用性
标准(朴素)算法 \(O(n^3)\) 非常高。是所有现代优化库实现的基础逻辑。
斯特拉森算法 \(O(n^{2.807})\) 中等。在处理非常大的稠密矩阵时有一定应用。
Coppersmith-Winograd等 \(\approx O(n^{2.373})\) 极低。仅有理论意义。
高度优化的库 (如cuBLAS) \(O(n^3)\) 最高。在实践中是性能的黄金标准。

因此,当被问及矩阵相乘的复杂度时:

  • 标准答案/理论答案\(O(m \cdot n \cdot p)\),对于方阵是 \(O(n^3)\)
  • 进阶答案:可以提及存在 \(O(n^{2.807})\) 等更快的理论算法,但它们在实践中受限。
  • 工程实践中的答案:理论复杂度是 \(O(n^3)\),但实际性能由硬件并行能力和优化库的效率决定。
posted @ 2025-08-12 00:52  AikNr  阅读(208)  评论(0)    收藏  举报