在深度学习中,计算图(Computational Graph) 是一种用于表示数学表达式或计算过程的有向无环图(DAG),它将复杂的计算分解为一系列基本运算(如加减乘除、激活函数等),并通过节点(Node)和边(Edge)来表示数据流动和依赖关系。计算图是自动求导(如反向传播)的基础,也是现代深度学习框架(如 PyTorch、TensorFlow)的核心抽象。
- 表示运算(Operation),如加法、乘法、激活函数等。
- 每个节点可以有零个或多个输入,以及一个或多个输出。
- 表示数据流动方向,边连接的是节点之间的输入和输出。
- 边携带的是张量(Tensor),即多维数组。
- 静态图(如 TensorFlow 1.x):先定义计算图,再执行计算(需显式编译)。
- 动态图(如 PyTorch、TensorFlow 2.x):在运行时动态构建计算图,代码执行时即时生成。
- 通过计算图从输入到输出的路径,依次执行各节点的运算,得到最终结果(如模型预测值)。
- 基于链式法则,从输出端反向计算每个节点对输入的梯度,用于参数更新。
- 计算图记录了前向传播的路径和中间结果,为反向传播提供必要信息。
- 框架可分析计算图,优化计算路径(如内存复用、算子融合)。
- 支持分布式计算和硬件加速(如 GPU 并行计算)。
![]()
import torch
# 创建可求导的张量
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
# 定义计算(动态构建计算图)
a = x + y # 加法节点
z = a * y # 乘法节点
# 反向传播计算梯度
z.backward() # 从z开始反向传播
# 输出梯度
print(f"dz/dx = {x.grad}") # 输出: 3.0
print(f"dz/dy = {y.grad}") # 输出: 8.0
计算图构建过程:
- 创建张量 x 和 y 时,设置
requires_grad=True 标记需要求导。
- 执行
a = x + y 和 z = a * y 时,PyTorch 自动记录每个运算的输入和输出,构建计算图。
- 调用
z.backward() 时,基于计算图反向传播梯度,计算每个变量的导数。
| 特性 |
静态图(如 TensorFlow 1.x) |
动态图(如 PyTorch) |
| 构建方式 |
先定义图结构,再执行(需显式编译) |
运行时动态构建(代码执行即构建) |
| 灵活性 |
低(修改图结构需重新编译) |
高(支持条件语句、循环等动态操作) |
| 调试难度 |
高(需先构建图,难以实时查看中间结果) |
低(可像普通 Python 代码一样调试) |
| 性能优化 |
优(编译时可全局优化计算路径) |
一般(动态优化能力有限) |
| 适用场景 |
部署需求高、计算图固定的场景(如移动端) |
研究、快速迭代、动态网络结构的场景 |
- 梯度检查点(Gradient Checkpointing):在反向传播时重新计算部分中间结果,减少内存占用。
- 内存复用:框架自动回收不再需要的张量内存。
- 算子融合(Operator Fusion):将多个连续运算合并为一个,减少内存访问开销(如将卷积 + ReLU 合并)。
- 分布式计算:将计算图分割到多个设备(如 GPU、TPU)并行执行。
- 静态计算图可导出为硬件友好的格式(如 ONNX),便于在移动端或专用硬件上部署。
计算图是深度学习框架的核心抽象,通过将复杂计算分解为基本运算并记录依赖关系,支持高效的前向传播和反向传播。动态图提供了灵活的开发体验,而静态图在性能优化和部署方面更具优势。理解计算图的工作原理有助于更高效地使用深度学习框架,并优化模型训练和推理过程。