Can Group Relative Policy Optimization Improve Thai Legal Reasoning and Question Answering? 论文复现
要复现 “Can Group Relative Policy Optimization Improve Thai Legal Reasoning and Question Answering?” (以下简称 “GRPO - 泰国法律 QA 论文”),需围绕 “泰国法律 QA 任务特性” 与 “GRPO 算法适配” 两大核心,分任务准备、数据构建、基线模型训练、GRPO 强化学习优化、评估验证五大步骤展开。以下是详细复现方案(注:若论文未开源,需基于 GRPO 通用逻辑与泰国法律 QA 任务特性补全关键细节):
一、复现前提与核心假设
由于论文未公开代码 / 数据,复现需基于以下合理假设(需根据论文原文调整,若有公开资源优先对齐):
- 任务定义:泰国法律 QA 任务为 “基于泰国法律条文(如《刑法》《民法典》),对法律问题生成准确回答 / 推理过程”,需评估回答的法律准确性、条文相关性、逻辑完整性。
- GRPO 核心适配点:GRPO(Group Relative Policy Optimization)通过 “组内相对优势” 优化策略,适合解决 “多候选答案排序” 场景 —— 泰国法律 QA 中,可生成多个候选回答,通过奖励模型区分优劣,再用 GRPO 更新策略。
- 基础模型:采用支持泰语的预训练模型(如泰语版 Llama-2、mT5-large、ThaiBERT)作为 Actor/Critic 基线。
- 工具依赖:使用
transformers
(模型加载)、verl
/rlhf
(GRPO 实现)、datasets
(数据处理)、scikit-learn
(评估指标计算)。
二、Step 1:任务与数据准备
1. 泰国法律 QA 数据集构建(核心难点)
泰国法律数据稀缺,需通过 “公开资源爬取 + 人工标注” 构建数据集,参考论文常见数据规模(1k~5k 样本):
- 数据来源:
- 公开法律数据库:泰国司法部官网(www.moj.go.th)、泰国法律在线库(Thai Legal Information Institute)爬取法律条文(如《刑法典》《民事诉讼法》);
- 法律问答样本:泰国律师论坛(如 Pantip Law 板块)、法院判例摘要(简化为 “问题 - 标准答案” 对)。
- 数据格式(单样本示例):
json
{ "question": "泰国刑法中,盗窃金额超过10万泰铢的量刑标准是什么?", # 泰语问题 "legal_context": "《泰国刑法典》第334条:盗窃金额超过10万泰铢的,处3年以上10年以下有期徒刑,并处罚金...", # 相关法律条文 "ground_truth": "根据《泰国刑法典》第334条,盗窃金额超过10万泰铢的,量刑为3年以上10年以下有期徒刑,并处罚金不超过20万泰铢。", # 标准答案 "negative_examples": [ # 劣质候选(用于奖励模型训练) "盗窃金额超过10万泰铢处1-3年有期徒刑(错误:量刑范围偏差)", "根据《民法典》第123条,盗窃需赔偿损失(错误:引用法律条文错误)" ] }
- 数据预处理:
- 泰语文本清洗:去除特殊符号、统一字体编码(避免泰语字符乱码);
- 候选生成:对每个问题,用基线模型生成 3~5 个候选回答(含 1 个优质、2~4 个劣质,劣质可通过 “故意错引条文”“简化逻辑” 生成)。
2. 数据集划分
- 训练集:70%(用于 SFT、奖励模型训练、GRPO 优化);
- 验证集:20%(超参数调优,如 GRPO 的
clip_param
、学习率); - 测试集:10%(最终评估,避免数据泄露)。
三、Step 2:基线模型训练(SFT 阶段)
先通过监督微调(SFT) 训练 “泰国法律 QA 基线模型”,为后续 GRPO 提供初始策略:
1. 模型选择与配置
- 基础模型:优先选择泰语优化模型,如
facebook/mbart-large-50-many-to-many-mmt
(多语言支持)、airesearch/wangchanberta-base-att-spm-uncased
(泰国本土预训练模型),或泰语版 Llama-2(需自定义权重转换); - 训练参数:
python运行
training_args = TrainingArguments( output_dir="./thai_legal_sft_model", per_device_train_batch_size=4, gradient_accumulation_steps=4, # 适配显存(泰语模型通常参数量较大) learning_rate=2e-5, num_train_epochs=3, logging_steps=10, fp16=True, # 用混合精度训练节省显存 push_to_hub=False )
2. 微调任务设计
- 输入格式:
"法律条文:{legal_context}\n问题:{question}\n回答:"
(引导模型结合条文生成回答); - 标签:
ground_truth
(标准答案); - 训练目标:最小化交叉熵损失,让模型学习 “法律条文→问题→准确回答” 的映射。
3. 基线模型评估
用BLEU-4、ROUGE-L、法律准确性评分(人工标注或规则匹配,如 “条文引用正确性”)评估 SFT 模型,确保基线性能达标(如 BLEU-4≥30,法律准确性≥60%),否则需扩大数据集或调整模型。
四、Step 3:奖励模型训练(RM 阶段)
GRPO 需依赖奖励模型(Reward Model)判断回答优劣,因此需训练一个 “能区分泰国法律 QA 回答质量” 的 RM:
1. 奖励模型架构
- 基础模型:复用 SFT 阶段的泰语模型(如 WangchanBERTa),头部替换为 “线性层 + Sigmoid”(输出 0~1 的奖励分数,1 表示最优,0 表示最差);
- 输入:
"法律条文:{legal_context}\n问题:{question}\n回答:{candidate}"
(候选回答); - 输出:奖励分数(反映回答的法律准确性、相关性)。
2. 训练数据构建( pairwise 偏好数据)
- 对每个样本,从 “优质候选(ground_truth 或 SFT 生成的高优回答)” 和 “劣质候选(negative_examples)” 中随机抽取 1 对,构成
(candidate_pos, candidate_neg)
对; - 训练目标:最小化 “优质候选分数 < 劣质候选分数” 的概率,采用对比损失:
python运行
def reward_loss(logits_pos, logits_neg): # logits_pos:优质候选的奖励分数,logits_neg:劣质候选的奖励分数 return -torch.log(torch.sigmoid(logits_pos - logits_neg) + 1e-8).mean()
3. 奖励模型评估
- 评估指标:偏好准确率(模型给优质候选的分数 > 劣质候选的比例,目标≥85%);
- 验证方式:在验证集随机抽取 100 对
(pos, neg)
,计算模型判断正确的比例,确保 RM 能有效区分优劣。
五、Step 4:GRPO 算法优化(RL 阶段)
这是复现核心,需将 GRPO 与泰国法律 QA 任务结合,通过奖励模型反馈优化 SFT 基线模型:
1. GRPO 核心原理回顾
GRPO 通过 “组内相对优势”(Group Relative Advantage)优化策略,避免 PPO 中 “单个样本优势偏差” 的问题,适合 “多候选排序” 场景 —— 在泰国法律 QA 中,可将 “同一问题的多个候选回答” 视为一个 “组”,计算组内相对优势后更新策略。
2. 关键参数配置(参考 GRPO 论文与泰国法律 QA 特性)
参数 | 建议值 | 说明 |
---|---|---|
batch_size |
8~16 | 组大小(每组含同一问题的 3~5 个候选回答) |
clip_param |
0.1~0.2 | GRPO 剪辑阈值(避免策略更新幅度过大) |
learning_rate |
1e-6~5e-6 | 低于 SFT(RL 阶段需缓慢更新) |
entropy_coef |
0.01~0.05 | 熵正则(鼓励模型探索不同法律推理路径) |
num_rollout_steps |
10~20 | 每轮 GRPO 的采样步数(生成候选回答的次数) |
group_size |
3~5 | 每组候选回答数量(越多越能体现 “相对优势”) |
3. GRPO 训练流程(基于verl
框架实现,若论文用自定义框架需适配)
python运行
from verl.algorithms.grpo import GRPOConfig, GRPOAlgorithm
from verl.core.actor_critic import ActorCriticModel
# 1. 加载SFT基线模型作为Actor/Critic初始权重
actor_critic = ActorCriticModel(
actor_config=dict(model_name_or_path="./thai_legal_sft_model"),
critic_config=dict(model_name_or_path="./thai_legal_sft_model"), # Critic复用Actor结构,输出价值估计
reward_model_config=dict(model_name_or_path="./thai_legal_rm_model") # 加载训练好的奖励模型
)
# 2. 配置GRPO参数
grpo_config = GRPOConfig(
batch_size=16,
clip_param=0.2,
learning_rate=3e-6,
entropy_coef=0.03,
num_rollout_steps=15,
group_size=4 # 每组4个候选回答
)
# 3. 初始化GRPO算法
grpo_alg = GRPOAlgorithm(
config=grpo_config,
actor_critic=actor_critic,
train_dataset=thai_legal_train_dataset, # 训练集(含问题、法律条文)
val_dataset=thai_legal_val_dataset
)
# 4. 启动GRPO训练
grpo_alg.train(
num_train_epochs=5,
output_dir="./thai_legal_grpo_model",
logging_steps=5
)
4. 关键细节适配(泰国法律 QA 特性)
- 组内候选生成:对每个问题,用 Actor 模型生成
group_size
个候选回答(通过调整temperature=0.7
增加多样性,避免候选重复); - 奖励计算:将每个候选回答输入奖励模型,得到 0~1 的分数,作为 GRPO 的 “原始奖励”;
- 相对优势计算:在每组内,将每个候选的奖励减去组内平均奖励,得到 “相对优势”,用于策略梯度更新(核心是 GRPO 与 PPO 的差异点);
- 法律条文约束:在生成候选回答时,强制模型引用至少 1 条相关法律条文(通过输入格式约束或输出正则检查),避免 “无依据推理”。
六、Step 5:评估与验证(复现论文核心结论)
需对齐论文的评估指标,验证 GRPO 是否比基线模型(SFT)、传统 RL 算法(如 PPO)在泰国法律 QA 任务上有提升:
1. 核心评估指标
- 法律准确性(最重要):
- 人工标注:邀请法律专业人员(或熟悉泰国法律的标注者)对回答的 “条文引用正确性”“量刑 / 条款解读准确性” 打分(1~5 分);
- 规则匹配:通过关键词匹配(如 “《刑法典》第 334 条” 是否在回答中,且与问题相关)自动计算准确率。
- QA 任务通用指标:BLEU-4、ROUGE-L(衡量回答与标准答案的文本相似度)、困惑度(Perplexity,衡量回答的流畅性)。
- RL-specific 指标:策略熵(反映探索能力)、奖励分数均值(GRPO 训练过程中奖励是否持续上升)。
2. 对比实验设计(需复现论文中的基线)
至少包含以下对比组,验证 GRPO 的优势:
模型 | 描述 |
---|---|
SFT Baseline | 仅经过监督微调的基线模型 |
PPO + RM | 用 PPO(传统 RL)+ 同一奖励模型优化的模型 |
GRPO(论文方法) | 用 GRPO + 同一奖励模型优化的模型(核心实验组) |
Human(上限) | 人类法律专家的回答(作为性能上限) |
3. 预期结果(参考 GRPO 通用优势与法律 QA 特性)
- GRPO 在法律准确性上比 SFT 提升 15%~25%,比 PPO 提升 8%~15%(因 GRPO 的相对优势更能区分 “细微法律差异”,如相似条文的准确引用);
- GRPO 的策略熵高于 PPO,说明其探索能力更强,能覆盖更多法律推理场景;
- 奖励分数均值在 GRPO 训练过程中持续上升,且最终稳定在 0.7~0.8(SFT 约 0.5,PPO 约 0.6)。
七、复现难点与解决方案
- 泰国法律数据稀缺:
- 解决方案:用 “数据增强” 生成更多样本(如对法律条文进行同义改写、对问题进行 paraphrase),或复用多语言法律数据(如英文法律 QA 数据翻译为泰语,再人工修正)。
- GRPO 算法实现细节:
- 解决方案:若论文未开源,可基于
verl
(字节跳动 RLHF 框架,支持 GRPO)或trl
(Hugging Face RL 库)的 GRPO 接口,调整参数适配泰国法律 QA;若需自定义,参考 GRPO 论文核心公式(组相对优势计算、损失函数)。
- 解决方案:若论文未开源,可基于
- 泰语模型显存限制:
- 解决方案:采用模型并行(如
device_map="auto"
)、混合精度训练(FP16/FP8)、梯度累积,或选择小参数量泰语模型(如 WangchanBERTa-base 而非 large)。
- 解决方案:采用模型并行(如
八、复现 checklist(确保完整性)
- 确认泰国法律 QA 数据集的格式与论文一致(含问题、法律条文、候选回答);
- 完成 SFT 基线模型训练,评估指标达标;
- 训练奖励模型,偏好准确率≥85%;
- 配置 GRPO 参数(参考论文或上述建议值),启动训练;
- 完成对比实验(SFT、PPO、GRPO),记录所有指标;
- 验证 GRPO 在法律准确性上的提升是否符合论文结论;
- 排查训练中的异常(如 NaN 损失、奖励不上升),调整超参数。
通过以上步骤,可系统复现论文核心实验,重点验证 GRPO 在 “泰国法律 QA 任务中对推理准确性、条文引用正确性” 的提升效果。
论文:https://modelscope.cn/papers/2507.09638
本文来自博客园,作者:limingqi,转载请注明原文链接:https://www.cnblogs.com/limingqi/p/19083340