Torch compile原理解析
问题:torch的eager模式是即时编译,无法跨算子优化,导致小kernel启动耗时高,模型推理或训练优化不充分。
Torch 1.0的解决方案:torch.jit可以捕获图并优化模型,但仍存在图捕获不完整的问题(trace不能获取if和for等条件控制流;script只兼容python的部分,不能实现全图优化),且存在python生态兼容性差的问题(不能与其他python库进行联合性能优化)。
Torch 2.0的解决方案2:torch.compile的目标:
目标一:完整的图捕获能力,通过CPython 的 Frame Evaluation API,直接在 Python 字节码层面拦截执行,提取可以编译的纯张量计算子图,遇到不支持的 Python 构造(比如 print、文件 IO、复杂的控制流)时,不是报错,而是在那个位置"断图"(graph break),回退到 eager 模式执行那一段,然后继续尝试捕获后面的部分。
TORCH_LOGS=graph_breaks python train.py
TORCH_LOGS=+fusion 可以看到完整的融合过程
TORCH_LOGS=output_code 输出编译产物
目标二:正确性优先,如果编译器不确定优化是安全的,就不做优化,通过"守卫"(guards)机制实现。
目标三:性能提升显著且可预测,三大性能提升来自算子融合,AOT Autograd,CUDA Graphs,第一次调用较慢,之后调用很快,因为编译器需要先分析字节码生成 FX Graph(Dynamo),前后向联合优化(AOT Autograd),算子分解和融合(Inductor),生成 Triton/C++ 代码并调用编译器,之后直接执行编译好的代码。
torch.compile
torch.compile 的编译流水线包含3批次:
三大批次,每批次5阶段、4优化、2路径:每个算子都需要通过这3个批次,到最终生成代码。
TorchDynamo + FX Graph
Python 函数
↓
[TorchDynamo] 字节码分析 + 图捕获
↓ 生成
FX Graph + Guards
↓
┌──────────────────────────────────┐
│ 批次 1:FX Graph Passes(Lowering 前) │
│ - 操作对象:FX Graph(Python 级别的 Graph/Node) │
│ - 算子粒度:ATen/Prims ops │
│ - 操作目标:减少节点数,增加Meta信息,清晰依赖关系 │
├──────────────────────────────────┤
│ ├─ 阶段1:Pre-Grad Passes │
│ │ • 结构规范化(split/concat/reshape 消除) │
│ │ • 形状传播(ShapeProp/FakeTensorProp,添加Meta信息) │
│ │ • Padding 调整 │
│ │ • 消除冗余clone │
│ │ • 拓扑排序(确保是拓扑序) │
│ │ │
│ ├─ 阶段2:AOT Autograd(训练阶段生成 Joint Graph) │
│ │ • 追踪前向+反向 │
│ │ • 应用分解表,生成可微分的低层算子,训练时生成反向图 │
│ │ • Backward管理精度:使用混合精度来加速和节省内存 │
│ │ • torch/_decomp/decompositions.py │
│ │ │
│ ├─ 阶段3:Joint Graph Passes(全局视角) │
│ │ • Peephole 优化(常量折叠、冗余视图消除transpose) │
│ │ • 随机数处理(Dropout,不保存mask tensor,只保存随机数)│
│ │ │
│ ├─ 阶段4:Min-Cut Partitioning │
│ │ • 分析 saved tensors,最小化需要保存的中间张量 │
│ │ • 图拆分为 fw_graph + bw_graph │
│ │ │
│ └─ 阶段5:Post-Grad Passes │
│ • 局部性重排(reorder_for_locality) │
│ • No-op 消除 │
│ • Reinplace(减少内存分配) │
│ • 分布式优化(分布式通信算子:多次通信融合,计算通信重叠) │
│ • FSDP优化(预取,内存压力检测,异步通信正确性保证,硬件特化)│
└───────────────────────────────────┘
TorchInductor + Inductor IR
FX Graph
↓ 前向图 + 后向图
┌──────────────────────────────────┐
│ 转换点:Lowering(FX Graph → Inductor IR) │
│ - torch/_inductor/lowering.py │
│ - 每个 ATen op → Inductor IR 节点 │
│ - 高层 → 低层:aten.add 一 Pointwise │
└──────────────────────────────────┘
↓
┌──────────────────────────────────┐
│ 批次 2:Inductor IR Passes(Lowering 后,Codegen 前) │
│ - 操作对象:Inductor IR(TensorBox/Buffer/Pointwise) │
│ - 这批 passes 主要在 Scheduler 中 │
├──────────────────────────────────┤
│ ├─ 依赖分析(Dependency Analysis) │
│ │ • 构建 IR 节点间的数据依赖图:强依赖RAW,弱依赖WAR/WAW │
│ │ • _inductor/scheduler.py: compute_dependencies() │
│ │ • 确定拓扑排序 │
│ │ • _inductor/scheduler.py: topological_sort_schedule() │
│ │ • 循环依赖检测 │
│ │ • _inductor/memory.py: validate_graph_acyclic() │
│ │ │
│ ├─ 融合决策(Fusion) │
│ │ • Pointwise 逐元素操作,融合 │
│ │ • Reduction 规约操作,融合策略:Persistent/Split-K │
│ │ • ExternKernel 调用外部库 │
│ │ • _inductor/scheduler.py: can_fuse() │
│ │ │
│ ├─ 内存规划(Memory Planning) │
│ │ • Buffer 生命周期分析 │
│ │ • 内存复用 Inplace │
│ │ • 峰值内存估算与重排 │
│ │ • memory.py │
│ │ │
│ └─ 布局优化(Layout Transformation) │
│ • Stride 重排 │
│ • Layout 传播 │
│ • Channels-last 转换 │
└──────────────────────────────────┘
↓
┌──────────────────────────────────┐
│ 批次 3:Codegen Passes(代码生成时的优化) │
│ - 操作对象:Triton/C++ AST │
│ - 基于Jinja2填充模板,生成设备特定优化 │
│ - torch/_inductor/codegen/ │
├──────────────────────────────────┤
│ ├─ Triton Codegen(GPU) │
│ │ • Tiling 策略选择 │
│ │ • Block size 优化 512-2048:太大,寄存器和共享内存不够用;太小,内核启动开销大 │
│ │ • Autotune(多配置选择):同个kernel可多版本自动调优 │
│ │ • _inductor/codegen/triton.py │
│ │ • @triton_heuristics.pointwise 装饰器 │
│ │ • @triton.jit triton jit 装饰器 + Kernel 主体 │
│ │ │
│ └─ C++ Codegen(CPU) │
│ • SIMD 向量化 │
│ • OpenMP 并行化 │
│ • codegen/cpp.py │
└──────────────────────────────────┘
↓
设备内核 (Triton/C++) triton编译器或g++编译后,通过PyBind加载回Python
↓
[Runtime] 守卫检查 + 缓存管理 + 重编译触发
↓
执行
TorchDynamo
Guards的目的是为了保障优化后的代码能正确执行。在构建 FX Graph 的过程中,TorchDynamo 会记录所有的假设(Tensor类型或shape等),这些假设会被编译成"守卫"代码,每次调用编译后的函数前,先运行这个守卫函数。如果返回 False,说明假设被违反,需要重编译。
FX Graph的对象是python对象,方便开发者直接打印与调试。包括Graph,Node,GraphModule。
1️⃣ Node:图中的一个操作,包括五种操作类型:placeholder输入参数;call_function调用函数;call_method调用方法;get_attr获取属性;output返回值。
2️⃣ Graph:节点的有序容器,维护了一个节点的双向链表,支持高效的插入、删除、遍历操作。
3️⃣ GraphModule:可执行的图,继承自 nn.Module,把一个 Graph 包装成可以直接调用的模块。
从 Python 函数生成 FX Graph 的过程,核心机制是用 Proxy 对象来拦截所有操作,即符号追踪,只对操作进行追踪,而不执行。追踪过程中,通过 Fake Tensor 记录形状和类型等信息,但不占用真实的内存。
FX Graph 支持各种子图匹配与替换,可以通过SubgraphMatcher实现。
Runtime 机制
1)守卫系统:编译时,Dynamo 生成所有guards假设;运行时,guards假设被验证。
2)缓存策略:变体管理,一个函数可能有多个编译版本,对应不同的输入特征;
3)重编译触发与动态形状泛化:形状、设备、dtype 等变化触发重编译
第一次调用(冷启动):
Frame Evaluation:Dynamo 拦截函数调用
Graph Capture:分析字节码,构建 FX Graph + Guard
Guard Generation:记录 x 的所有属性
AOT Autograd(如果是训练模式):生成前后向图
IR Lowering:FX → Core ATen → Prims
Inductor:
算子分解:relu, add
融合分析:两个 pointwise 可以融合
代码生成:生成 Triton kernel
编译:调用 Triton 编译器
Cache:存储编译结果
Execute:运行生成的 kernel
Return:返回结果
第二次调用(热路径):
Guard Check:检查 x 的属性(< 1 微秒)
Cache Hit:找到编译版本
Execute:直接运行 kernel
Return:返回结果
形状变化时(重发重编译):
Guard Check:失败(形状不匹配)
Recompilation:
尝试动态形状泛化
重新编译(约 1-5 秒)
Cache Update:存储新版本
Execute:运行新 kernel
Return:返回结果
尝试案例
import torch
import torch.nn as nn
# 启用调试模式
import os
os.environ['TORCH_COMPILE_DEBUG'] = '1'
# 定义模型
class Model(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(4, 4))
def forward(self, x):
mm = torch.mm(x, self.weight)
return torch.relu(mm)
# 编译
model = Model()
compiled = torch.compile(model, backend="inductor")
# 运行(触发编译)
x = torch.randn(4, 4, requires_grad=True)
output = compiled(x)
loss = output.sum()
loss.backward()
# 查看生成的文件
# 在 /tmp/torchinductor_<username>/ 目录下
# fx_graph_readable.py - 各阶段的 FX Graph
# output_code.py - 生成的 Triton/C++ 代码
参考文献
PyTorch 2.x 编译系统详解:https://mp.weixin.qq.com/s/kqhkyOgk45adnHHFw7O_GQ

浙公网安备 33010602011771号