SFT模仿,RL泛化
SFT-vs-RL:基础模型后训练中泛化与记忆的比较研究
- 论文核心概念🔍
这篇论文的核心问题是:在基础模型(如大型语言模型和视觉语言模型)的后训练阶段,监督微调(SFT)和强化学习(RL)两种技术如何影响模型的泛化能力和记忆倾向。泛化能力指的是模型将学到的知识应用到新任务或数据变体的能力,而记忆指的是模型简单复制训练数据中的模式。
论文的关键发现是:
• 问题:SFT 和 RL 都被广泛用于后训练,但之前不清楚它们对泛化的具体影响。SFT 可能只是记忆训练数据,导致在分布外(OOD)数据上表现差;而 RL 可能通过学习通用原则来提升泛化。
• 解决思路:论文通过设计新任务(如 GeneralPoints)和利用现有任务(如 V-IRL),在文本和视觉环境中系统比较 SFT 和 RL。方法上,采用多步 RL 框架,结合顺序修订和基于结果的奖励,以评估模型在规则变化和视觉变化下的表现。
• 核心 Insight:RL(尤其是使用基于结果的奖励时)在文本和视觉任务中都能有效泛化到未见过的变体,而 SFT 倾向于记忆训练数据,泛化能力弱。此外,SFT 对 RL 训练有稳定作用,但 RL 本身能提升模型的潜在视觉识别能力。
- 论文方法🔬
2.1 过去方法的问题
在基础模型的后训练中,SFT 和 RL 是常用技术,但过去的研究存在局限性:
• SFT 的局限性:SFT 通过在特定任务数据上微调模型,使模型适应下游任务。先前工作(如 FLAN 和 LIMA)表明 SFT 能提升零样本性能并调整输出格式,但它可能过度拟合训练数据,导致在 OOD 数据上表现下降。论文指出,SFT 更像“格式教师”,但容易记忆而非学习可转移的原理。
• RL 的局限性:RL 通常用于对齐人类偏好或优化特定任务(如 RLHF),但过去研究多关注单一模态(如仅文本或仅视觉),或仅比较一种后训练方法。缺乏在多模态任务中系统比较 SFT 和 RL 对泛化影响的工作。
• Motivation:因此,论文动机是填补这一空白,通过多模态任务(文本和视觉)来区分记忆和泛化,并揭示 RL 在复杂环境中的优势。
2.2 整体框架
论文的方法框架包括任务设计、强化学习设置和训练流程。整体目标是评估 SFT 和 RL 在泛化方面的表现,框架基于多步 RL 和顺序修订。
2.2.1 任务设计
论文使用两个主要任务来评估泛化:
• GeneralPoints:一个原创的算术推理卡牌游戏,类似于 24 点游戏。模型接收四张卡牌(以文本或图像形式),需使用每张卡的数字恰好一次计算目标数字(默认 24)。任务变体包括规则变化(如卡牌 J、Q、K 的数字解释不同)和视觉变化(如卡牌颜色不同)。
• V-IRL:一个现实世界导航任务,专注于空间推理。模型需根据语言指令和视觉观察(街景图像)导航到目标位置。任务变体包括动作空间变化(如绝对方向 vs. 相对方向)和视觉变化(如不同城市的地标)。
这些任务允许测试模型在文本规则和视觉输入下的 OOD 泛化。
2.2.2 强化学习框架
论文采用多步 RL 框架,类似于 Zhai et al. (2024a) 和顺序修订公式(Snell et al., 2024)。框架将模型视为策略网络,使用 PPO 算法进行优化。
• 基本 RL 术语:
• 状态空间 \mathcal{S}:对于视觉语言模型(VLM),\mathcal{S} := \mathcal{V}^m \times \mathcal{O},其中 \mathcal{V}^m 是输入文本空间(最大词元长度 m),\mathcal{O} 是图像空间。对于纯语言模型(LLM),\mathcal{S} := \mathcal{V}^m。
• 动作空间 \mathcal{A}:\mathcal{A} := \mathcal{V}^n,表示输出文本空间(最大词元长度 n)。
• 奖励函数 r: \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}:基于验证器(VER)的输出,生成奖励。
• 目标:学习策略 \pi: \mathcal{S} \rightarrow \mathcal{A} 以最大化期望回报 \max_{\pi \in \Pi} \mathbb{E}{\pi} \left[ \sum^{T} r_t \right],其中 (r_t = r(s_t, a_t)),T 是最大步数。
• 验证器(VER):
• 验证器评估模型输出 v^{\text{out}},并返回基于结果的奖励 r 和文本信息 v{\text{ver}}。数学上,(\text{VER}(v_t{\text{out}}) \mapsto (r_t, v_t^{\text{ver}}))。
• 例如,在 GeneralPoints 中,验证器检查生成的等式是否正确;在 V-IRL 中,检查动作是否匹配专家轨迹。
• 顺序修订公式:
• 这是多步 RL 的核心,允许模型在多个时间步中修订输出。状态转移基于提示的更新:
◦ 在时间步 t=0,初始输入 v_0^{\text{in}} 包含系统提示。
◦ 对于 t \geq 1,输入 v_t^{\text{in}} 是系统提示与所有先前模型输出和验证器输出的连结:\(v_t^{\text{in}} = \text{concat}(v_0^{\text{in}}, [v_k^{\text{out}}, v_k^{\text{ver}}]_{k=0}^{t-1})\)。
• 这使模型能根据反馈迭代改进答案。流程如图 2 所示。
• 策略网络和训练:
• 模型(如 Llama-3.2-Vision-11B)作为策略网络 \pi_\theta,参数为 \theta。
• 使用 PPO(Proximal Policy Optimization)算法更新策略,最大化回报。训练流程先进行 SFT 初始化,然后运行 RL。
• 奖励设计示例:
◦ GeneralPoints:正确等式奖励 +5,错误等式惩罚 -1 到 -3。
◦ V-IRL:正确动作奖励 +1,错误动作惩罚 -1。
2.2.3 训练流程细节
• SFT 阶段:使用任务特定的监督数据(如专家轨迹)微调模型,使模型适应任务格式。
• RL 阶段:从 SFT 初始化后的模型开始,通过多步 RL 进行训练。关键点是:
• 状态 s_t 由当前提示和视觉输入(如有)构成。
• 动作 a_t 是模型生成的文本响应。
• 奖励 r_t 由验证器基于结果给出。
• 公式上,回报最大化问题可分解为:
\[
\max_{\pi} \mathbb{E}_{\pi} \left[ \sum_{t=0}^{T} r_t \right]
\]
其中,PPO 算法通过策略梯度方法优化 \pi_\theta,具体包括:
◦ 收集经验轨迹。
◦ 计算优势函数。
◦ 更新策略参数 \theta 以提升期望回报。
• 计算估计:训练计算量(FLOPs)估计为 X_{\text{train}} = 6ND_{\text{train}},其中 N 是模型参数数量,D_{\text{train}} 是训练词元数。RL 额外有推理计算 X_{\text{inference}} = 2ND_{\text{buffer}}。
整个框架确保了模型能在多轮交互中学习泛化知识,而非简单记忆。
- 实验结果与分析📊
3.1 实验结果
论文在 GeneralPoints 和 V-IRL 任务上比较了 SFT 和 RL 的性能,包括分布内(ID)和分布外(OOD)评估。
• Baseline:
• 初始模型(Init):未经过后训练的基础模型(Llama-3.2-Vision-11B)。
• SFT:监督微调后的模型。
• RL:从 SFT 初始化后,通过多步 RL 训练的模型。
• 此外,与先前最先进方法(如 Yang et al., 2024a 的 V-IRL 基准)比较。
• 数据集:
• GeneralPoints:原创卡牌游戏数据集,从标准扑克牌中抽取四张牌保证有解。包括:
◦ 文本变体(GP-L):卡牌以文本描述。
◦ 视觉变体(GP-VL):卡牌以图像呈现。
◦ 规则变化:如 J、Q、K 解释为 10(ID)或 11、12、13(OOD)。
◦ 视觉变化:如黑色花色训练,红色花色测试(OOD)。
• V-IRL:现实世界导航数据集,基于 Yang et al. (2024a)。包括:
◦ 文本变体(V-IRL-L):纯语言描述。
◦ 视觉变体(V-IRL-VL):包含视觉输入(街景图像)。
◦ 规则变化:绝对方向动作空间(ID)vs. 相对方向动作空间(OOD)。
◦ 视觉变化:在纽约市训练,在全球多个城市测试(OOD)。
• 数据集大小:GeneralPoints 使用采样数据;V-IRL 使用 1000 条纽约路线训练,VLN 小基准(18 条全球路线)测试。
• 评估指标:
• GeneralPoints:成功率(%),即模型在回合中生成正确等式的比例。
• V-IRL:
◦ 每步准确率(%):模型在每个导航步骤中选择正确动作的比例。
◦ 整体成功率(%):模型在整个路径上所有步骤都正确的比例(更严格)。
• 统计处理:使用 Savitzky-Golay 滤波器平滑曲线,误差条基于二项分布近似。
关键结果摘要:
• RL 在 OOD 规则变化和视觉变化下均表现更好,例如在 GP-VL 上 OOD 性能提升 +3.0%,在 V-IRL-VL 上提升 +9.3%。
• SFT 在 OOD 评估中性能下降,如 GP-L 上下降 -8.1%,表明记忆倾向。
• RL 还提升了视觉识别准确率,作为泛化的副产品。
• 多步验证迭代(如 10 步)进一步改善泛化。
实验结果图表示例(如 Figure 5 和 Figure 6)显示了性能趋势,RL 在扩展训练计算时持续提升,而 SFT 饱和或下降。
RL 促进泛化,SFT 倾向于记忆:论文发现,在基础模型的后训练阶段,强化学习(RL)能够帮助模型学习可泛化的知识,使其在未见过的任务变体(如规则变化或视觉变化)上表现良好。相比之下,监督微调(SFT)更容易记忆训练数据,导致在分布外(OOD)数据上性能下降。这意味着 RL 更擅长捕捉通用原则,而 SFT 可能只是过度拟合训练样本。
RL 提升视觉识别能力:作为一个副产品,RL 训练还能提高视觉语言模型(VLM)的视觉识别准确率。例如,在视觉任务中,RL 帮助模型更好地识别图像中的关键元素(如卡牌数字或地标),从而增强整体泛化能力。
SFT 对 RL 训练有稳定作用:尽管 RL 在泛化方面更优,但论文表明 SFT 是有效 RL 训练的基础。SFT 通过调整模型输出格式(如确保响应符合 JSON 结构),使模型能更好地遵循指令,从而为后续的 RL 训练提供稳定的起点。没有 SFT 初始化,RL 可能因模型指令跟随能力差而失败。
这些结果验证了 RL 在泛化方面的优势,并强调了后训练方法选择的重要性。

浙公网安备 33010602011771号