O1复盘
简单复盘
o1 和 r1 之间间隔了四个月,这段时间里,rule-based reward
并没有被主流技术方案所认可。我们不妨做个简单的复盘,去思考下在那探索的四个月中,为什么大家更青睐于 prm / mcts
路线?为什么沿着这条路线做不出来突破?以及到底有哪些关键点是当时所被忽略的?
base model and data distribution are all you need
在复现 r1 的工作中,qwen
和 llama
展现出了不同的趋势,qwen-base、qwen-instruct、qwen-math-base 也展现出了不同的趋势。rl 本来就是一个发掘模型潜力的方法,如果模型没有潜力,那还发掘什么呢?
同理,数据分布也一样,目前开源的复现 r1 工作已经很多了,但也不是说任何工作都能和它们技术报告一样完美复现。orz 的 57K 数据,就是一个整理的比较好的数据,训练较为稳定。对待数据,我们要像 kimi1.5 报告中所说的,观测很多统计指标:每个 prompt 的多次采样的平均准确率、平均输出长度、是否不经过 cot 能直接说出答案等,甚至 prompt 的 ground truth 是否易于 verifier 进行判别(或提前统计好,或在训练过程中利用课程学习的思想动态调控)。
过去的时间内,小作坊团队们还是过于聚焦在“厨艺”上了,对“食材”的分析反倒是有所欠缺。
scaling is all you need
“结果正确”就是比“过程正确 + 结果正确”训出来的模型效果好,即使是现在也没人敢打包票吧。
r1 的成功并不能宣判 prm 的死刑,它只能说明,100W 条 orm 数据 > 1W 条 prm 数据。换个角度,深度学习的一个发展范式就是“scaling + 雕花”:先大力出奇迹,再一点点缩小成本、优化细节。显然 orm 属于 scaling,prm 属于雕花,不是 prm 的技术路线有问题,而是对 prm 的投入应该放在 orm 已经完全吃透之后。
infra is all you need
和 llm-reasoning 这两个方向有什么区别吗?prompt 的难度和 response 的长度。
以往我们使用 rlhf 的场景主要是:安全问题、诗歌创作修复韵脚问题、简单的代码数学问题等,往往都是几百个 token 就搞定的任务。现在不一样了,模型的一条回复就达到了上万个 token,这会让显存压力和解码时长陡增。作为应对,vllm
或 sglang 已经成为 rlhf 框架的标配。最初我也认为 grpo 省掉 critic_model 这一点并不关键,现在看来我还是只考虑了算法和数据的视角,并没有充分理解到额外维护一个和 actor_model 相同规模的 critic_model,对训练框架的稳定性有多大的挑战。
当模型 size 和 response_len 逐渐增大,“训练效率和预防训练过程中莫名其妙的 OOM ”就是复现 r1 工作中最大的难点(对,就是莫名其妙,在程序断了之前,你不会知道明明限制了最大 response_len,为啥它还会 OOM)。截止目前,大模型的强化算法没有什么创新,grpo 也已经出现一年半了。如果说各公司复现 r1 的进度不太一致,那大概率是因为各公司的 rl-infra 水平参差不齐吧。
hyper-parameter is all you need
因为之前没有坚定 rule-based reward 这个路线,所以大多数团队不会花很多时间去调整学习率、kl 系数、回报折扣系数、n_samples、使用的强化算法 …… 仅凭几组参数就敲定实验结论。再加上算法同学们通常都对训练数据十分敏感,观察到“过程瞎说但结果蒙对了的 response”,就会更坚定的认为当时所使用的 orm 技术方案不行。
也在技术报告里提到说他们尝试过了 prm / mcts 的路线,只不过遇到了种种瓶颈,最终选择回归到 orm 的路线,然后“顺利”地做出了 zero 和 r1。
没有足够的技术自信和训练算力,确实不会下那么大的决心去调参,所以 r1 注定是要由精英团队率先做出来的。
好的流程远胜不靠谱的算法trick
llm的sft和rl,笔者认为,二者差别不大,sft是rl的一个特例(有一些文章做了类似的讨论),而rl则更好的利用了负样本。在dapo中,一个核心是dynamic-sampling,简单来说,根据当前模型在prompt的bon,动态决定采样budget,难prompt采样更多的sample,简单prompt则采样更少的prompt,以及过滤模型解决不了的hard-prompt或者easy-prompt。
在sft阶段,通常也会使用类似的策略做code/math等等的拒绝采样、筛选多样性response(embedding+聚类、长度)。从DAPO中可以看出,一个良好的pipline(online-dynamic-sampling)远胜于不靠谱的算法trick。
当做好sft后,从数据/task、response合成/采样、response挑选/打分方法等等,都有一个相对固定且运行良好的流程。把这个流程做到online,在replay-buffer
的数据构造中即可应用,配合对应的挑选/打分/筛选策略,便可将sft阶段积累的优秀流程直接迁移到online-rl。同时也需要replay-buffer和主代码解耦,做灵活的控制。
总之,能做好sft且pipline能够在线化运行的团队,做好online-rl只是算力和时间的问题(生产要素)。反之,则陷入一个窘境(生产关系):
- 做sft的一直offline调数据、蒸馏、挑选,但pipline较难在线化运行,且需要人力不断重复,但实际上都是well-defined流程和配比实验,不太需要过多的人工参与;(出现能力/任务冲突后,人工介入处理)
- 做rl的不断重复sft的数据流程:找数据、找replay-buffer的数据构建策略,踩过一坨坑后,发现,这些策略其实和sft并无不同,造成了极大的资源浪费和时间浪费。
- 做agent-rl的时候,agent-rl只需要写一个推理引擎的多次采样即可,而环境的稳定性则更为重要。如果sft没怎么做过agent-based的sft数据,则环境积累基本为0,当应用agent-rl的时候,环境稳定性会成为rl训练的阿喀琉斯之踵。尤其是agent环境,延时、返回结果的不确定性等等会加剧这个问题。
误区1:r1 的关键在于对 base 模型做强化
如果熟读 dpsk 的技术报告,应该会有印象:基于 base 模型做 rl 得到的 zero 模型,仅仅是在 r1 的训练过程中提供了一些 long cot 的启动数据罢了,这个冷启动数据还要经过数据专员的大幅度修正。也就是说,r1 的训练,是基于一个 sft model 进行的(1000 条冷启动数据训练)。
- 对 base model 直接做 rl,好处在于模型没有受到任何限制, explore 的空间极大,有发挥的空间,缺点在于模型的 follow 格式能力很差;
- 对 long cot model 做 rl,好处在于模型 follow 格式的能力很强,同时在 long cot sft 阶段,会被灌输很多正确的思考模式,缺点大概是模型的起始输出长度过长;
- 对 instruct model 做 rl,复现 r1 的效果应该是最差的,因为模型的思考模式已经有些固化了,explore 空间比较小。
三种 setting,孰优孰劣,我的认知也未必正确,大家自己在实验中自己找感觉吧。也可以花时间比较一下模型在 rl 之前、rl 之后的 response 有什么变化,统计一下最高频的 N-gram 看看。
误区2:r1 的复现在于看见 response_len 的稳定增长
在 response_len 这个指标上,很多人都有点倒果为因了,认为观察到 response_len 和 reward 一起上涨就代表模型的效果在变好。
复杂的问题,往往需要更高级的思考逻辑,而高级的思考逻辑往往会具有更长的长度。换言之,在不考虑 attention 衰减的情况下,response_len 几乎与 explore_space 呈正比关系 —— 所以我们并不是希望 response_len 上涨,而是希望模型 explore 到更有价值的或者是全新的思考逻辑。如果只是想增加 response_len,我们有一百种方法,把 eos_token 的 probability 调小一倍,开 dropout,加各种 noise 让模型训崩,诱导模型出现复读机现象 …… 有意义吗?显然没有。
reflection pattern 也是同理,不是说看见模型说出了 however、wait,就代表模型具有反思能力了。而是要看是否这些反思 pattern 帮助模型提高了准确率,或者说是,带反思 token 的 response 的 accuracy 是否真的高于 response 的平均 accuracy。此外,不同的 reflection pattern 对 accuracy 的贡献也不相同,try another approach 就是比 compute again 要高级一些,模型能不能在 reward 持续上涨的过程中,自发地提升优质 reflection pattern 的出现概率,也是我们要观察的重点。
误区3:打点太少,只关注 reward、loss、response_len
工欲善其事,必先利其器,打点就是做 rl 时最重要的“器”。毕竟打点又不增加训练耗时,能多记录信息就多记录,就算是把每条 response 的生成耗时都写到 tensorboard 里也不会让人觉得奇怪。
比如下面这些指标,我们都可以记录下来,万一就能从中找到些灵感呢?
- 模型的输出同质化指标:最常见的指标自然是 policy_model response 的 entropy。但也可以是别的,同一 prompt 下 N 条 response 的编辑距离,N-gram 重复比例等都可;
- 模型的各种 response_len:答对时候的 len,答错时候的 len,答对且有反思时候的 len,答错且有反思时候的 len 等;
- 模型的 ACC:不仅是整体的 ACC,也要有各种设定下的 ACC,比如单条 prompt 下的 ACC,出现反思 pattern 时的 ACC,高于平均 response_len 时的 ACC 等;
- 模型的异常现象占比:response 没有 follow 格式,response 超出预设长度了,response 出现复读机现象,或者是 repeat_score 过高,response 中英混杂等;
- 算法的异常现象占比:ppo / grpo 都有一个 clip 操作,记录 clip 发生的频率,到底是上溢多还是下溢多,溢出的比例随着实验推进会发生什么变化?
- ……
误区4:忘了实验目标是什么
在和朋友做 r1 的技术交流的时候,我发现大家都会时不时陷入到一个死胡同里去。具体来说,就是为了让某个指标好看而忽略了实验原本的目标是什么?最典型的死胡同莫过于:想方设法的让 policy_model response 的熵不要低于某个值。
- 为什么希望熵别太低呢?因为希望模型的 N 条 response 足够有区分度,信息量足够大;
- 那具体的做法是什么呢?加 entropy_loss,调 temperature,调 clip 的上下界;
- 有效果吗?可能有,但更可能把模型训崩溃;
- 分析模型为什么崩溃?可这怎么去分析呢,熵 loss 和 语言模型 loss 本来就是相反的目标,temperature 过大则会让各种牛鬼蛇神 token 都冒出来;
- 陷入痛苦,为什么我的“改动”让模型训崩溃了!
简单来说,模型 response 趋向于同质化是一个必然现象,rlhf 的目标不就是:把模型从 decode 10 次答对一次,训成 decode 3 次答对一次吗?把 path@N 的收益集中在 path@1 上。
如果模型没训崩溃,reward 和 response_len 持续增长,那就不要纠结 response 多样性是否过低的问题了。或者我们就直接换数据啊,response 多样性低大概率是 prompt 太难或太简单了,考虑用课程学习思想淘汰这条数据。这个问题上太纠结于设计 loss 很容易走火入魔。
总之,目前的 long cot 实验重点是两个目标:reward 稳定提高,response 探索出一些高级的思考 pattern。至于什么 response_len 持续增长,什么 aha moment,并不是实验的关键!要知道,o1 和 r1 目前更擅长的其实是 planning,而不是无休止的“wait、but、however”。
未来展望
数学底子好的人,就应该继续去优化 rl 算法,往大了搞可以提出一个新的算法,往小了搞就从理论上去证明 grpo / ppo 中 clip 公式、kl_loss、advantage_norm、batch_norm 等是否合理,是否有优化的空间。
卡多的人则可以搞排列组合,是否加 kl loss,是否直接丢弃超长的 response,是否应该给损失函数加熵,是否应该动态调 temperature,是否对 prompt 引入课程学习等等 —— 消融实验会证明一切。
infra 能力强的人,则应该持续优化训练框架的稳定性和训练效率,解决一些很底层但都被大家选择性忽略的问题,比如推理引擎 vllm 的结果和 model.forward() 其实还是有不少精度 diff 的。
前瞻性强的人,就该收一收手头的 math 工作,赶紧搞 code 了。AIME2024 显然已经走上了 GSM8K 的后尘,被大家玩的炉火纯青。除此之外,如何像 dpsk 一样把模型从 code / math 上学到的 long cot 能力,泛化到模型的通用能力上,亦是一个极具挑战的难题,也是 o1 / r1 真正牛的地方。
(白嫖的人,可以等千问团队发布 qwen3 和 qwq 的详细报告与 checkpoint 哈,感觉快来了)
ORZ 57K数据处理:
数量、多样性和质量。遵循这些关键方面,我们通过全面的收集和清理过程来管理我们的数据集:
- 我们从各种来源收集公共数据,包括AIME(截至2023年)、MATH[9]、Numina MATH集合[10]、Tulu3 MATH[11]、0penR1-MATH-220k[12]和其他开源数据集。根据来源和问题难度,我们检索AMC、AIME、Math、0lympiads和AoPS论坛组件作为我们的难度级别提示,以确保适当的难度级别。
- 我们使用程序化方法合成额外的推理任务来增强数据集。·我们排除了基于规则的奖励函数难以评估的问题,如多项选择题和面向证明的问题,确保了训练过程中奖励计算的准确性和一致性。我们实现了一种基于模型的过滤策略,该策略基于问题难度的启发式评估。具体来说,我们使用LLM来评估每个问题的通过率,删除通过率过高或为零的样本。
- 最终的精选数据由大约129k个跨越数学和推理领域的样本组成。该集合专门用于增强模型在复杂问题解决任务中的能力,仔细平衡数量、多样性和质量。
此外,我们利用32B模型本身的训练过程来识别具有挑战性和高质量的提示。我们最初使用从完整的129k样本数据集中采样的数据对32B模型进行1100步的训练。随后,我们确定了特别困难的提示,即模型在总共64次尝试中获得的正确答案少于4个,导致大约13k个具有挑战性的提示。然后,在100个额外步骤的最后训练阶段,有选择地使用这些识别的提示,旨在解决模型的薄弱环节,提高在困难推理任务上的性能。
Skywork-OR1数据与RL训练策略
为提升模型在数学和代码方面能力,Skywork-OR1构建了一个高质量数学和代码数据集。
团队设计了三个标准进行数据筛选:可验证性(Verifiable)、正确性(Correct)与挑战性(Challenging),剔除无法自动验证的证明类题目、有误题目、和缺少unit test的代码问题。
数学领域共计收集11万道题目,主要依赖NuminaMath-1.5(含约89.6万题),选用如AIME和Olympiads等较难子集,并补充了如DeepScaleR、Omni-MATH、AIME 1983-2023难题来源。
代码领域收集了13.7k条高质量代码问题,主要以LeetCode和TACO数据为主,保留了单元测试完整、验证通过的问题,并进行向量级语义去重。
在数据过滤部分,团队对每道题进行了多轮采样并验证答案,以避免“全对”或“全错”现象对策略学习无效——模型生成全部错误,无法提供有效的学习信号;“全对”意味着模型已完全掌握,继续学习会浪费计算资源。
并通过人类审核结合LLM自动判题机制,对语义不清、信息不全、格式错误或含有无关内容的项目进行清理。使用LLM-as-a-Judge剔除掉约1-2K道质量不达标的数学题。
其次在强化学习部分,Skywork-OR1使用GRPO(Group Relative Policy Optimization)进行训练,并引入一系列优化策略。
在训练时数据优化上,一方面采用双重过滤策略:
- 离线过滤:训练前使用待训练模型评估数据,剔除正确率为0或1的样本;
- 在线过滤:每个epoch动态移除上一轮已完全掌握的数据,确保模型持续面对有挑战性的内容。
另一方面使用拒绝采样(Rejection Sampling)进行更精细的实时筛选,在每个训练步骤中动态剔除当前训练步中采样正确率为0或1的样本。这样可以维持policy loss、entropy loss和KL loss的合理比例,防止非policy loss比重异常增加导致的训练不稳定。
在训练Pipeline优化上主要做了两方面的探索。
(1)多阶段训练(Multi Stage Training):从小窗口开始,逐步增加上下文长度(seq_len),可以促使模型在有限token内高效完成任务;随后逐步扩展窗口大小,迭代增加生成长度,使模型逐渐掌握更复杂的长链思维能力。实验证明,多阶段训练能显著缩短训练时间,同时完全保持模型的长度扩展能力。
(2)截断优势掩码(Truncated Advantage Mask):在多阶段训练初期,由于上下文窗口限制,复杂问题的回答可能被截断。因此团队研究了两种处理窗口限制下截断样本的策略Adv-Mask Before(计算优势前排除截断样本)和Adv-Mask After(计算后将截断样本优势置零)。证明即使不屏蔽截断样本,模型也能有效适应并迅速提升性能,也证明多阶段训练框架的鲁棒性。
此外,在强化学习训练中还要保障模型的探索能力。
团队进行了三方面探索。
第一,高温度采样。采用τ=1.0(高于常见的0.6)维持更高群组内多样性,既保证足够正确样本提供学习信号,又允许模型探索更广泛解决路径。
第二,提升内在训练多样性。通过精细数据过滤、增加批量大小和减少数据重复使用,可以从源头上防止模型过早优化到单一输出方向,同时也保持较高熵值,避免局部最优。
第三,自适应熵控制。只有在熵值低于阈值时才提供熵增加鼓励,设定目标熵值并动态调整损失系数,同时最小化对正常训练轨迹的干扰。
最后在保障强化学习训练的稳定性,团队对损失函数进行优化。
第一,移除KL损失。研究中发现即使基于高质量SFT模型训练,KL损失仍限制性能提升。因此,除特定阶段外,团队在所有公开发布的Skywork-OR1系列模型中均未使用KL损失项,这使模型能够更充分地探索和优化推理能力。
第二,token级策略损失。移除了策略损失中的长度归一化项,并将损失在训练批次内的所有token上进行平均,以提升优化过程的一致性与稳定性。
浙公网安备 33010602011771号