Torch compile原理解析

问题:torch的eager模式是即时编译,无法跨算子优化,导致小kernel启动耗时高,模型推理或训练优化不充分。

Torch 1.0的解决方案:torch.jit可以捕获图并优化模型,但仍存在图捕获不完整的问题(trace不能获取if和for等条件控制流;script只兼容python的部分,不能实现全图优化),且存在python生态兼容性差的问题(不能与其他python库进行联合性能优化)。

Torch 2.0的解决方案2:torch.compile的目标:

目标一:完整的图捕获能力,通过CPython 的 Frame Evaluation API,直接在 Python 字节码层面拦截执行,提取可以编译的纯张量计算子图,遇到不支持的 Python 构造(比如 print、文件 IO、复杂的控制流)时,不是报错,而是在那个位置"断图"(graph break),回退到 eager 模式执行那一段,然后继续尝试捕获后面的部分。

TORCH_LOGS=graph_breaks python train.py
TORCH_LOGS=+fusion 可以看到完整的融合过程
TORCH_LOGS=output_code 输出编译产物

目标二:正确性优先,如果编译器不确定优化是安全的,就不做优化,通过"守卫"(guards)机制实现。

目标三:性能提升显著且可预测,三大性能提升来自算子融合,AOT Autograd,CUDA Graphs,第一次调用较慢,之后调用很快,因为编译器需要先分析字节码生成 FX Graph(Dynamo),前后向联合优化(AOT Autograd),算子分解和融合(Inductor),生成 Triton/C++ 代码并调用编译器,之后直接执行编译好的代码。

torch.compile

torch.compile 的编译流水线包含3批次:
三大批次,每批次5阶段、4优化、2路径:每个算子都需要通过这3个批次,到最终生成代码。

TorchDynamo + FX Graph

Python 函数
    ↓
[TorchDynamo] 字节码分析 + 图捕获
    ↓ 生成
FX Graph + Guards
    ↓
┌──────────────────────────────────┐
│ 批次 1:FX Graph Passes(Lowering 前)                    │
│  - 操作对象:FX Graph(Python 级别的 Graph/Node)          │
│  - 算子粒度:ATen/Prims ops                              │
│  - 操作目标:减少节点数,增加Meta信息,清晰依赖关系            │
├──────────────────────────────────┤
│  ├─ 阶段1:Pre-Grad Passes                             │
│  │   • 结构规范化(split/concat/reshape 消除)            │
│  │   • 形状传播(ShapeProp/FakeTensorProp,添加Meta信息)  │
│  │   • Padding 调整                                     │
│  │   • 消除冗余clone                                     │
│  │   • 拓扑排序(确保是拓扑序)                            │
│  │                                                     │
│  ├─ 阶段2:AOT Autograd(训练阶段生成 Joint Graph)       │
│  │   • 追踪前向+反向                                      │
│  │   • 应用分解表,生成可微分的低层算子,训练时生成反向图       │
│  │   • Backward管理精度:使用混合精度来加速和节省内存         │
│  │   • torch/_decomp/decompositions.py                 │
│  │                                                     │
│  ├─ 阶段3:Joint Graph Passes(全局视角)                │
│  │   • Peephole 优化(常量折叠、冗余视图消除transpose)     │
│  │   • 随机数处理(Dropout,不保存mask tensor,只保存随机数)│
│  │                                                     │
│  ├─ 阶段4:Min-Cut Partitioning                         │
│  │   • 分析 saved tensors,最小化需要保存的中间张量         │
│  │   • 图拆分为 fw_graph + bw_graph                      │
│  │                                                      │
│  └─ 阶段5:Post-Grad Passes                             │
│      • 局部性重排(reorder_for_locality)                  │
│      • No-op 消除                                        │
│      • Reinplace(减少内存分配)                           │
│      • 分布式优化(分布式通信算子:多次通信融合,计算通信重叠)    │
│      • FSDP优化(预取,内存压力检测,异步通信正确性保证,硬件特化)│
└───────────────────────────────────┘

TorchInductor + Inductor IR

                       FX Graph
                          ↓ 前向图 + 后向图
┌──────────────────────────────────┐
│   转换点:Lowering(FX Graph → Inductor IR)             │
│  - torch/_inductor/lowering.py                          │
│  - 每个 ATen op → Inductor IR 节点                       │
│  - 高层 → 低层:aten.add 一 Pointwise                    │
└──────────────────────────────────┘
                          ↓
┌──────────────────────────────────┐
│ 批次 2:Inductor IR Passes(Lowering 后,Codegen 前)     │
│  - 操作对象:Inductor IR(TensorBox/Buffer/Pointwise)    │
│  - 这批 passes 主要在 Scheduler 中                        │
├──────────────────────────────────┤
│  ├─ 依赖分析(Dependency Analysis)                     │
│  │   • 构建 IR 节点间的数据依赖图:强依赖RAW,弱依赖WAR/WAW  │
│  │   • _inductor/scheduler.py: compute_dependencies()  │
│  │   • 确定拓扑排序                                      │
│  │   • _inductor/scheduler.py: topological_sort_schedule() │
│  │   • 循环依赖检测                                      │
│  │   • _inductor/memory.py: validate_graph_acyclic()   │
│  │                                                     │
│  ├─ 融合决策(Fusion)                                   │
│  │   • Pointwise 逐元素操作,融合                          │
│  │   • Reduction 规约操作,融合策略:Persistent/Split-K    │
│  │   • ExternKernel 调用外部库                            │
│  │   • _inductor/scheduler.py: can_fuse()               │
│  │                                                      │
│  ├─ 内存规划(Memory Planning)                          │
│  │   • Buffer 生命周期分析                                │
│  │   • 内存复用 Inplace                                  │
│  │   • 峰值内存估算与重排                                  │
│  │   • memory.py                                        │
│  │                                                      │
│  └─ 布局优化(Layout Transformation)                    │
│      • Stride 重排                                       │
│      • Layout 传播                                       │
│      • Channels-last 转换                                │
└──────────────────────────────────┘
                          ↓
┌──────────────────────────────────┐
│ 批次 3:Codegen Passes(代码生成时的优化)                  │
│  - 操作对象:Triton/C++ AST                              │
│  - 基于Jinja2填充模板,生成设备特定优化                     │
│  - torch/_inductor/codegen/                            │
├──────────────────────────────────┤
│  ├─ Triton Codegen(GPU)                              │
│  │   • Tiling 策略选择                                  │
│  │   • Block size 优化 512-2048:太大,寄存器和共享内存不够用;太小,内核启动开销大 │
│  │   • Autotune(多配置选择):同个kernel可多版本自动调优    │
│  │   • _inductor/codegen/triton.py                     │
│  │   • @triton_heuristics.pointwise 装饰器              │
│  │   • @triton.jit triton jit 装饰器 + Kernel 主体       │
│  │                                                     │
│  └─ C++ Codegen(CPU)                                 │
│      • SIMD 向量化                                       │
│      • OpenMP 并行化                                     │
│      • codegen/cpp.py                                   │
└──────────────────────────────────┘
    ↓
设备内核 (Triton/C++)  triton编译器或g++编译后,通过PyBind加载回Python
    ↓
[Runtime] 守卫检查 + 缓存管理 + 重编译触发
    ↓
执行

TorchDynamo

Guards的目的是为了保障优化后的代码能正确执行。在构建 FX Graph 的过程中,TorchDynamo 会记录所有的假设(Tensor类型或shape等),这些假设会被编译成"守卫"代码,每次调用编译后的函数前,先运行这个守卫函数。如果返回 False,说明假设被违反,需要重编译。

FX Graph的对象是python对象,方便开发者直接打印与调试。包括Graph,Node,GraphModule。
1️⃣ Node:图中的一个操作,包括五种操作类型:placeholder输入参数;call_function调用函数;call_method调用方法;get_attr获取属性;output返回值。
2️⃣ Graph:节点的有序容器,维护了一个节点的双向链表,支持高效的插入、删除、遍历操作。
3️⃣ GraphModule:可执行的图,继承自 nn.Module,把一个 Graph 包装成可以直接调用的模块。
从 Python 函数生成 FX Graph 的过程,核心机制是用 Proxy 对象来拦截所有操作,即符号追踪,只对操作进行追踪,而不执行。追踪过程中,通过 Fake Tensor 记录形状和类型等信息,但不占用真实的内存。
FX Graph 支持各种子图匹配与替换,可以通过SubgraphMatcher实现。

Runtime 机制

1)守卫系统:编译时,Dynamo 生成所有guards假设;运行时,guards假设被验证。
2)缓存策略:变体管理,一个函数可能有多个编译版本,对应不同的输入特征;
3)重编译触发与动态形状泛化:形状、设备、dtype 等变化触发重编译

第一次调用(冷启动):
Frame Evaluation:Dynamo 拦截函数调用
Graph Capture:分析字节码,构建 FX Graph + Guard
Guard Generation:记录 x 的所有属性
AOT Autograd(如果是训练模式):生成前后向图
IR Lowering:FX → Core ATen → Prims
Inductor:
    算子分解:relu, add
    融合分析:两个 pointwise 可以融合
    代码生成:生成 Triton kernel
    编译:调用 Triton 编译器
Cache:存储编译结果
Execute:运行生成的 kernel
Return:返回结果

第二次调用(热路径):
Guard Check:检查 x 的属性(< 1 微秒)
Cache Hit:找到编译版本
Execute:直接运行 kernel
Return:返回结果

形状变化时(重发重编译):
Guard Check:失败(形状不匹配)
Recompilation:
    尝试动态形状泛化
    重新编译(约 1-5 秒)
Cache Update:存储新版本
Execute:运行新 kernel
Return:返回结果

尝试案例

import torch
import torch.nn as nn

# 启用调试模式
import os
os.environ['TORCH_COMPILE_DEBUG'] = '1'

# 定义模型
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
    
    def forward(self, x):
        mm = torch.mm(x, self.weight)
        return torch.relu(mm)

# 编译
model = Model()
compiled = torch.compile(model, backend="inductor")

# 运行(触发编译)
x = torch.randn(4, 4, requires_grad=True)
output = compiled(x)
loss = output.sum()
loss.backward()

# 查看生成的文件
# 在 /tmp/torchinductor_<username>/ 目录下
# fx_graph_readable.py - 各阶段的 FX Graph
# output_code.py - 生成的 Triton/C++ 代码

参考文献

PyTorch 2.x 编译系统详解:https://mp.weixin.qq.com/s/kqhkyOgk45adnHHFw7O_GQ

posted @ 2025-12-04 14:35  qccz123456  阅读(16)  评论(0)    收藏  举报