ML-SYS 学习宝典:从 RLHF 系统到 SGLang 深入解析
项目简介
Awesome-ML-SYS-Tutorial 是一个专注于机器学习系统(ML SYS)领域的学习笔记与代码仓库。它旨在为对 ML 与系统交叉领域感兴趣的研究者和工程师提供高质量的学习资源。项目内容涵盖了从基础理论(如强化学习、马尔科夫决策过程)到前沿系统框架(如 RLHF 训练系统、SGLang 推理引擎)的深度解析,并包含了大量实战配置指南和核心代码走读。
该项目的核心价值在于其实践性与系统性:不仅提供了如何配置开发环境、使用 Docker、管理依赖等实操指南,还深入剖析了 slime、verl、OpenRLHF 等主流开源框架的架构设计与核心实现,帮助读者理解大规模机器学习训练与推理背后的系统设计思想。
功能特性
- 全面的 RLHF 系统开发笔记:详细记录了 veRL、slime、OpenRLHF 等框架的设计理念、工作流程、核心模块(如 Rollout、Training、Buffer)及异步训练、多轮对话、工具调用等高级特性的实现。
- 深入的 SGLang 推理引擎解析:从请求生命周期、KV Cache 管理、分布式并行(TP/DP)、权重更新机制到多模态(如 Qwen2.5-VL)和扩散模型支持,进行了全方位的代码走读和原理阐述。
- 训推不一致的系统性解决方案:探讨了 RL 训练中因数值精度、算子差异导致的训练与推理不匹配问题,并介绍了 slime 框架中“真正 On-Policy 训练”和“算法缓解(TIS/MIS)”两种解决方案。
- 实用的环境配置与工具指南:提供了基于 Docker 的可复现环境配置方法、高效的 bash/zsh 配置、uv 包管理工具的使用,以及如何在复杂集群上配置开发环境。
- 核心算法与代码实现:包含 PPO、GRPO、SPIN、Online DPO 等主流强化学习算法的理论推导、公式解析及其在具体框架(如 verl、trl)中的代码实现。
- 前沿工作复现与分析:对 Search-R1、LUFFY、Kimi K1.5 等前沿研究工作的算法思想和实现细节进行了学习和复现笔记。
安装指南
本项目主要为学习笔记和代码分析,不依赖于单一的安装脚本。但项目内包含了大量环境配置的实践指导:
- 基础环境:推荐使用 Docker 来创建隔离且可复现的开发环境。可以参考项目中的 Docker 配置指南(如使用
lmsysorg/sglang:latest或nvcr.io/nvidia/pytorch等基础镜像)。 - Python 环境:建议使用 uv 作为快速的 Python 包管理器,并搭配虚拟环境(venv)。
# 创建虚拟环境 python3 -m venv ~/.python/myenv source ~/.python/myenv/bin/activate # 安装 uv python3 -m pip install uv - 框架安装:针对不同的学习模块,需要安装对应的框架。
- verl (with SGLang):
git clone https://github.com/volcengine/verl.git cd verl python3 -m uv pip install -e ".[sglang]" --prerelease=allow - slime:
git clone https://github.com/THUDM/slime.git cd slime pip install -e . - SGLang:
git clone https://github.com/sgl-project/sglang.git cd sglang pip install -e "python[all]"
- verl (with SGLang):
- 依赖管理:注意处理 PyTorch、CUDA、flash-attn、transformers 等依赖的版本兼容性问题,具体版本需参考各框架的官方要求。
使用说明
快速开始:运行一个 RLHF 训练示例
以下以在 verl 框架中使用 SGLang 运行一个 GSM8K 多轮 GRPO 训练为例:
- 准备环境和数据:
# 拉取并预处理数据集 cd verl python examples/data_preprocess/gsm8k_multiturn_w_tool.py - 启动训练脚本:
该脚本内部会配置 Ray 集群、启动 SGLang 推理引擎,并开始 GRPO 训练循环。# 设置 GPU 并运行训练(示例脚本) export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
核心概念:理解 SGLang 中的请求处理流程
SGLang 作为高性能推理引擎,其核心是高效处理并发的推理请求。一个请求的生命周期大致如下(简化):
# 伪代码,示意 SGLang 内部流程
# 1. 请求接收与 Tokenization
# 用户请求通过 FastAPI endpoint 进入
async def v1_chat_completions(request: ChatCompletionRequest):
req_input = convert_to_generate_req_input(request)
# TokenizerManager 进行分词和多模态数据处理
tokenized = tokenizer_manager.tokenize(req_input)
# 将请求放入 Scheduler 的等待队列
scheduler.add_request(tokenized)
# 2. 调度与批处理 (Scheduler)
# Scheduler 事件循环
while True:
# 从等待队列中根据优先级(如最长前缀)选取请求组成批次
batch = scheduler.get_next_batch()
if batch.prefill_tokens > 0:
# Prefill 阶段:处理输入提示,填充 KV Cache
run_prefill_batch(batch)
# Decode 阶段:自回归生成 token
next_token_ids = run_decode_batch(batch)
# 处理生成结果,更新请求状态
scheduler.process_batch_results(batch, next_token_ids)
# 如果请求完成,将结果发送给 DetokenizerManager
for req in batch.finished_requests:
detokenizer_manager.send_output(req)
# 3. 结果返回
# DetokenizerManager 将 token ID 解码为文本,通过 HTTP 返回给用户。
API 概览:verl 训练配置关键参数
在 verl 等框架中,训练通过配置文件(如 YAML)或命令行参数控制。以下是一些关键参数:
| 参数组 | 参数名 | 说明 |
|---|---|---|
| Data | train_batch_size |
每次训练迭代使用的提示(Prompt)数量。 |
| Data | max_prompt_length |
提示的最大 token 长度。 |
| Actor/Rollout | ppo_mini_batch_size |
PPO 训练中,将经验数据分割成的 mini-batch 大小。 |
| Actor/Rollout | rollout.n |
每个提示采样多少条回复(Responses)。 |
| Rollout Engine | rollout.name |
指定推理引擎,如 sglang 或 vllm。 |
| Multi-turn | rollout.multi_turn.enable |
是否启用多轮对话训练。 |
| Tool Calling | rollout.trace.backend |
启用轨迹追踪(如 weave),用于分析工具调用。 |
核心代码
1. slime 中 GAE 的 Chunk-Scan 并行计算
该优化解决了长序列下 GAE 计算串行导致的性能瓶颈。核心思想是将时间序列分块,并行计算局部 GAE,再通过前缀扫描合并。
# 代码片段位于 slime 相关 PR (#850)
# 核心思想:将 GAE 计算转化为可并行的前缀扫描问题
def chunk_scan_gae(full_rewards, full_values, gamma, lambd, chunk_size):
"""
full_rewards: [B, T]
full_values: [B, T+1]
"""
B, T = full_rewards.shape
num_chunks = (T + chunk_size - 1) // chunk_size
# 1. 将数据分块
reward_chunks = full_rewards.split(chunk_size, dim=1)
value_chunks = full_values.split(chunk_size, dim=1)
# 2. 并行计算每个 chunk 的局部 delta 和 GAE(伪代码)
# 每个 chunk 内部是串行的,但 chunk 之间可以并行处理
chunk_results = []
for i in range(num_chunks):
chunk_rewards = reward_chunks[i]
chunk_values = value_chunks[i]
next_values = value_chunks[i+1] if i+1 < len(value_chunks) else 0.0
# 计算该 chunk 的 delta 和局部 GAE
local_delta = chunk_rewards + gamma * next_values - chunk_values
local_gae = compute_sequential_gae(local_delta, gamma*lambd) # 内部串行
chunk_results.append((local_delta, local_gae))
# 3. 前缀扫描合并 chunk (简化示意)
full_advantages = torch.zeros_like(full_rewards)
carry = 0.0 # 跨 chunk 的累积因子
for i in range(num_chunks):
local_delta, local_gae = chunk_results[i]
# 将上一个 chunk 的尾部影响加到当前 chunk 的 GAE 上
adjusted_gae = local_gae + carry
# 更新 carry 用于下一个 chunk
carry = (gamma * lambd) ** chunk_size * local_gae[:, -1:]
# 存储结果
start_idx = i * chunk_size
end_idx = start_idx + chunk_size
full_advantages[:, start_idx:end_idx] = adjusted_gae
return full_advantages
代码注释:
- 传统 GAE 计算需要对时间步
t从T-1到0进行串行循环,无法利用 GPU 并行能力。 chunk_scan_gae函数首先将长度为T的序列划分为多个chunk。- 每个
chunk内部的 GAE 计算仍是串行的,但不同的chunk可以并行计算,这显著提高了计算吞吐。 - 之后,通过一个轻量的“前缀扫描”步骤,将前一个
chunk的末端 GAE 值(carry)传播到后一个chunk,从而合并得到整个序列正确的 GAE。 - 该优化在超长序列的 Agentic RL 场景下,可带来 100-300 倍 的加速。
2. SGLang 中多模态请求的 Token 扩展与特征注入
以 Qwen2.5-VL 为例,展示了 SGLang 如何处理包含图像的请求。
# 代码思想基于 sglang/runtime/multimodal_extensions/qwen_vl.py
# 关键步骤:Tokenizer 扩展与 M-RoPE 位置编码
def process_qwen_vl_request(generate_req_input):
"""
generate_req_input: 包含 text 和 image_data 的请求输入
"""
text = generate_req_input.text
image_data_list = generate_req_input.image_data
# 1. 并发加载和预处理图像
pixel_values_list = []
for img_data in image_data_list:
# 加载图像,并应用模型特定的 resize (如 smart_resize)
pixel_values = load_and_preprocess_image(img_data)
pixel_values_list.append(pixel_values)
# 2. Tokenization 与即时 Token 扩展
# 原始 prompt 可能包含类似 `<|vision_start|><image><|vision_end|>` 的占位符
# Tokenizer 会直接将 `<image>` 替换为一连串特定的 image placeholder tokens (如 <|image_pad|>)
input_ids = tokenizer.encode(text)
# 此时 input_ids 中已经包含了代表图像区域的特殊 token 序列
# 3. 计算 M-RoPE (Multimodal Rotary Position Embedding) 位置
# 这为图像 token 和文本 token 提供了统一的、精确的位置信息
mrope_positions = compute_mrope_positions(input_ids, pixel_values_list)
# 4. 构建多模态数据项
mm_items = []
for pv in pixel_values_list:
mm_items.append(MultimodalDataItem(pixel_values=pv))
# 5. 返回给调度器
return {
"input_ids": input_ids,
"mm_items": mm_items,
"mrope_positions": mrope_positions
}
# 在模型前向传播时 (如 qwen2_5_vl.py)
def forward_in_model(input_ids, mrope_positions, mm_items):
# 获取常规的词嵌入(包含特殊 token)
embeddings = word_embedding(input_ids)
# 应用 RoPE,位置信息由 mrope_positions 提供
embeddings = apply_rotary_pos_emb(embeddings, mrope_positions)
# 识别出 input_ids 中图像占位符 token 的位置
image_token_indices = find_image_token_indices(input_ids)
# 获取视觉特征
image_features = vision_transformer(mm_items.pixel_values)
# 将视觉特征投影到语言模型嵌入空间
projected_image_features = project_to_lm_dim(image_features)
# 将视觉特征注入到对应位置的 embeddings 中
embeddings[image_token_indices] = projected_image_features
# 后续送入 LLM backbone 进行计算
logits = llm_backbone(embeddings)
return logits
代码注释:
- 多模态请求处理的关键在于Token 扩展和特征注入。
- Token 扩展发生在 Tokenizer 阶段,图像占位符被直接替换为一系列预定义的特殊 Token,这使得输入序列在结构上包含了图像信息。
- M-RoPE 为融合后的序列(文本 Token + 图像占位 Token)计算统一的位置编码,确保注意力机制能理解各部分的位置关系。
- 在模型前向传播时,视觉特征由视觉编码器(如 Vision Transformer)提取,并投影到语言模型的嵌入维度,然后精确地替换嵌入层中对应图像占位符位置的向量。
- 这种设计使得图像信息能够“无缝”地融入语言模型的推理流程,同时保持了 SGLang 调度和缓存机制的有效性(如图片缓存基于 pixel values 的哈希)。
3. 训推不一致的算法缓解:Masked Importance Sampling (MIS)
在 slime 框架中,除了追求比特级对齐的“真正 On-Policy”方案,还提供了基于重要性采样的算法缓解方案。
# 代码思想基于 slime 中训推不一致缓解的实现
# Masked Importance Sampling (MIS) 核心公式应用
def compute_mis_corrected_loss(actor_log_probs, # 训练引擎计算的 log prob
rollout_log_probs, # 推理引擎 rollout 时计算的 log prob
advantages, # 优势函数估计
clip_epsilon=0.2,
mismatch_threshold=0.1): # 失配阈值
"""
通过重要性采样权重来修正策略梯度,抑制训推差异过大的样本。
"""
# 1. 计算重要性权重 (importance weight)
# 即 rollout 策略与当前训练策略的概率比
log_ratio = rollout_log_probs - actor_log_probs
ratio = torch.exp(log_ratio)
# 2. 计算原始的 PPO 裁剪损失
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
original_pg_loss = -torch.min(surr1, surr2).mean()
# 3. 计算训推差异度量(例如,每个 token 的 KL 散度)
kl_per_token = actor_log_probs - rollout_log_probs # 近似 KL
avg_kl = kl_per_token.mean(dim=-1) # 序列平均 KL
# 4. 构造 Mask (基于阈值)
# 差异过大的样本,其梯度会被抑制
mask = (avg_kl < mismatch_threshold).float().unsqueeze(-1) # 扩张维度以匹配 token 级
# 5. 应用 MIS:将 mask 作为权重乘到原始损失上
# 或者更精细地,调整重要性权重
mis_ratio = ratio * mask
surr1_mis = mis_ratio * advantages
surr2_mis = torch.clamp(mis_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
mis_corrected_loss = -torch.min(surr1_mis, surr2_mis).mean()
# 也可以选择将 mask 直接作用于损失
# mis_corrected_loss = (original_pg_loss * mask.mean())
return mis_corrected_loss, avg_kl, mask
代码注释:
- 训推不一致 指即使模型权重相同,训练引擎和推理引擎计算出的 token 对数概率也存在微小差异,本质上是浮点运算顺序等系统原因造成的异策略(off-policy)效应。
- MIS 核心思想:识别出训推差异(可用 KL 散度度量)过大的样本,并在策略梯度更新时降低这些样本的权重(
mask)。 ratio是重要性采样权重,衡量 rollout 时策略与当前训练策略的偏离程度。mask是一个 0/1 掩码,当序列的平均 KL 低于阈值时,mask=1,否则为0。这相当于过滤掉了差异过大的“不可靠”样本。- 将
mask应用到ratio或直接应用到损失上,可以抑制因训推不一致而产生的有害梯度,从而提高训练稳定性,尤其是在 MoE 模型或长序列任务中。 - 这种方法是一种效率与正确性之间的折中,相比于实现比特级对齐,它的开销更小,但能有效缓解不匹配带来的负面影响。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)
公众号二维码

公众号二维码


浙公网安备 33010602011771号