Can Group Relative Policy Optimization Improve Thai Legal Reasoning and Question Answering? 论文复现

要复现 “Can Group Relative Policy Optimization Improve Thai Legal Reasoning and Question Answering?” (以下简称 “GRPO - 泰国法律 QA 论文”),需围绕 “泰国法律 QA 任务特性” 与 “GRPO 算法适配” 两大核心,分任务准备、数据构建、基线模型训练、GRPO 强化学习优化、评估验证五大步骤展开。以下是详细复现方案(注:若论文未开源,需基于 GRPO 通用逻辑与泰国法律 QA 任务特性补全关键细节):

一、复现前提与核心假设

由于论文未公开代码 / 数据,复现需基于以下合理假设(需根据论文原文调整,若有公开资源优先对齐):
  1. 任务定义:泰国法律 QA 任务为 “基于泰国法律条文(如《刑法》《民法典》),对法律问题生成准确回答 / 推理过程”,需评估回答的法律准确性、条文相关性、逻辑完整性。
  2. GRPO 核心适配点:GRPO(Group Relative Policy Optimization)通过 “组内相对优势” 优化策略,适合解决 “多候选答案排序” 场景 —— 泰国法律 QA 中,可生成多个候选回答,通过奖励模型区分优劣,再用 GRPO 更新策略。
  3. 基础模型:采用支持泰语的预训练模型(如泰语版 Llama-2、mT5-large、ThaiBERT)作为 Actor/Critic 基线。
  4. 工具依赖:使用transformers(模型加载)、verl/rlhf(GRPO 实现)、datasets(数据处理)、scikit-learn(评估指标计算)。

二、Step 1:任务与数据准备

1. 泰国法律 QA 数据集构建(核心难点)

泰国法律数据稀缺,需通过 “公开资源爬取 + 人工标注” 构建数据集,参考论文常见数据规模(1k~5k 样本):
  • 数据来源:
    • 公开法律数据库:泰国司法部官网(www.moj.go.th)、泰国法律在线库(Thai Legal Information Institute)爬取法律条文(如《刑法典》《民事诉讼法》);
    • 法律问答样本:泰国律师论坛(如 Pantip Law 板块)、法院判例摘要(简化为 “问题 - 标准答案” 对)。
  • 数据格式(单样本示例):
    json
    {
      "question": "泰国刑法中,盗窃金额超过10万泰铢的量刑标准是什么?",  # 泰语问题
      "legal_context": "《泰国刑法典》第334条:盗窃金额超过10万泰铢的,处3年以上10年以下有期徒刑,并处罚金...",  # 相关法律条文
      "ground_truth": "根据《泰国刑法典》第334条,盗窃金额超过10万泰铢的,量刑为3年以上10年以下有期徒刑,并处罚金不超过20万泰铢。",  # 标准答案
      "negative_examples": [  # 劣质候选(用于奖励模型训练)
        "盗窃金额超过10万泰铢处1-3年有期徒刑(错误:量刑范围偏差)",
        "根据《民法典》第123条,盗窃需赔偿损失(错误:引用法律条文错误)"
      ]
    }
    
     
  • 数据预处理:
    • 泰语文本清洗:去除特殊符号、统一字体编码(避免泰语字符乱码);
    • 候选生成:对每个问题,用基线模型生成 3~5 个候选回答(含 1 个优质、2~4 个劣质,劣质可通过 “故意错引条文”“简化逻辑” 生成)。

2. 数据集划分

  • 训练集:70%(用于 SFT、奖励模型训练、GRPO 优化);
  • 验证集:20%(超参数调优,如 GRPO 的clip_param、学习率);
  • 测试集:10%(最终评估,避免数据泄露)。

三、Step 2:基线模型训练(SFT 阶段)

先通过监督微调(SFT) 训练 “泰国法律 QA 基线模型”,为后续 GRPO 提供初始策略:

1. 模型选择与配置

  • 基础模型:优先选择泰语优化模型,如 facebook/mbart-large-50-many-to-many-mmt(多语言支持)、airesearch/wangchanberta-base-att-spm-uncased(泰国本土预训练模型),或泰语版 Llama-2(需自定义权重转换);
  • 训练参数:
    python运行
    training_args = TrainingArguments(
        output_dir="./thai_legal_sft_model",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # 适配显存(泰语模型通常参数量较大)
        learning_rate=2e-5,
        num_train_epochs=3,
        logging_steps=10,
        fp16=True,  # 用混合精度训练节省显存
        push_to_hub=False
    )

2. 微调任务设计

  • 输入格式:"法律条文:{legal_context}\n问题:{question}\n回答:"(引导模型结合条文生成回答);
  • 标签:ground_truth(标准答案);
  • 训练目标:最小化交叉熵损失,让模型学习 “法律条文→问题→准确回答” 的映射。

3. 基线模型评估

用BLEU-4、ROUGE-L、法律准确性评分(人工标注或规则匹配,如 “条文引用正确性”)评估 SFT 模型,确保基线性能达标(如 BLEU-4≥30,法律准确性≥60%),否则需扩大数据集或调整模型。

四、Step 3:奖励模型训练(RM 阶段)

GRPO 需依赖奖励模型(Reward Model)判断回答优劣,因此需训练一个 “能区分泰国法律 QA 回答质量” 的 RM:

1. 奖励模型架构

  • 基础模型:复用 SFT 阶段的泰语模型(如 WangchanBERTa),头部替换为 “线性层 + Sigmoid”(输出 0~1 的奖励分数,1 表示最优,0 表示最差);
  • 输入:"法律条文:{legal_context}\n问题:{question}\n回答:{candidate}"(候选回答);
  • 输出:奖励分数(反映回答的法律准确性、相关性)。

2. 训练数据构建( pairwise 偏好数据)

  • 对每个样本,从 “优质候选(ground_truth 或 SFT 生成的高优回答)” 和 “劣质候选(negative_examples)” 中随机抽取 1 对,构成(candidate_pos, candidate_neg)对;
  • 训练目标:最小化 “优质候选分数 < 劣质候选分数” 的概率,采用对比损失:
    python运行
    def reward_loss(logits_pos, logits_neg):
        # logits_pos:优质候选的奖励分数,logits_neg:劣质候选的奖励分数
        return -torch.log(torch.sigmoid(logits_pos - logits_neg) + 1e-8).mean()
    
     

3. 奖励模型评估

  • 评估指标:偏好准确率(模型给优质候选的分数 > 劣质候选的比例,目标≥85%);
  • 验证方式:在验证集随机抽取 100 对(pos, neg),计算模型判断正确的比例,确保 RM 能有效区分优劣。

五、Step 4:GRPO 算法优化(RL 阶段)

这是复现核心,需将 GRPO 与泰国法律 QA 任务结合,通过奖励模型反馈优化 SFT 基线模型:

1. GRPO 核心原理回顾

GRPO 通过 “组内相对优势”(Group Relative Advantage)优化策略,避免 PPO 中 “单个样本优势偏差” 的问题,适合 “多候选排序” 场景 —— 在泰国法律 QA 中,可将 “同一问题的多个候选回答” 视为一个 “组”,计算组内相对优势后更新策略。

2. 关键参数配置(参考 GRPO 论文与泰国法律 QA 特性)

参数建议值说明
batch_size 8~16 组大小(每组含同一问题的 3~5 个候选回答)
clip_param 0.1~0.2 GRPO 剪辑阈值(避免策略更新幅度过大)
learning_rate 1e-6~5e-6 低于 SFT(RL 阶段需缓慢更新)
entropy_coef 0.01~0.05 熵正则(鼓励模型探索不同法律推理路径)
num_rollout_steps 10~20 每轮 GRPO 的采样步数(生成候选回答的次数)
group_size 3~5 每组候选回答数量(越多越能体现 “相对优势”)

3. GRPO 训练流程(基于verl框架实现,若论文用自定义框架需适配)

python运行
from verl.algorithms.grpo import GRPOConfig, GRPOAlgorithm
from verl.core.actor_critic import ActorCriticModel

# 1. 加载SFT基线模型作为Actor/Critic初始权重
actor_critic = ActorCriticModel(
    actor_config=dict(model_name_or_path="./thai_legal_sft_model"),
    critic_config=dict(model_name_or_path="./thai_legal_sft_model"),  # Critic复用Actor结构,输出价值估计
    reward_model_config=dict(model_name_or_path="./thai_legal_rm_model")  # 加载训练好的奖励模型
)

# 2. 配置GRPO参数
grpo_config = GRPOConfig(
    batch_size=16,
    clip_param=0.2,
    learning_rate=3e-6,
    entropy_coef=0.03,
    num_rollout_steps=15,
    group_size=4  # 每组4个候选回答
)

# 3. 初始化GRPO算法
grpo_alg = GRPOAlgorithm(
    config=grpo_config,
    actor_critic=actor_critic,
    train_dataset=thai_legal_train_dataset,  # 训练集(含问题、法律条文)
    val_dataset=thai_legal_val_dataset
)

# 4. 启动GRPO训练
grpo_alg.train(
    num_train_epochs=5,
    output_dir="./thai_legal_grpo_model",
    logging_steps=5
)
 

4. 关键细节适配(泰国法律 QA 特性)

  • 组内候选生成:对每个问题,用 Actor 模型生成group_size个候选回答(通过调整temperature=0.7增加多样性,避免候选重复);
  • 奖励计算:将每个候选回答输入奖励模型,得到 0~1 的分数,作为 GRPO 的 “原始奖励”;
  • 相对优势计算:在每组内,将每个候选的奖励减去组内平均奖励,得到 “相对优势”,用于策略梯度更新(核心是 GRPO 与 PPO 的差异点);
  • 法律条文约束:在生成候选回答时,强制模型引用至少 1 条相关法律条文(通过输入格式约束或输出正则检查),避免 “无依据推理”。

六、Step 5:评估与验证(复现论文核心结论)

需对齐论文的评估指标,验证 GRPO 是否比基线模型(SFT)、传统 RL 算法(如 PPO)在泰国法律 QA 任务上有提升:

1. 核心评估指标

  • 法律准确性(最重要):
    • 人工标注:邀请法律专业人员(或熟悉泰国法律的标注者)对回答的 “条文引用正确性”“量刑 / 条款解读准确性” 打分(1~5 分);
    • 规则匹配:通过关键词匹配(如 “《刑法典》第 334 条” 是否在回答中,且与问题相关)自动计算准确率。
  • QA 任务通用指标:BLEU-4、ROUGE-L(衡量回答与标准答案的文本相似度)、困惑度(Perplexity,衡量回答的流畅性)。
  • RL-specific 指标:策略熵(反映探索能力)、奖励分数均值(GRPO 训练过程中奖励是否持续上升)。

2. 对比实验设计(需复现论文中的基线)

至少包含以下对比组,验证 GRPO 的优势:

模型描述
SFT Baseline 仅经过监督微调的基线模型
PPO + RM 用 PPO(传统 RL)+ 同一奖励模型优化的模型
GRPO(论文方法) 用 GRPO + 同一奖励模型优化的模型(核心实验组)
Human(上限) 人类法律专家的回答(作为性能上限)

3. 预期结果(参考 GRPO 通用优势与法律 QA 特性)

  • GRPO 在法律准确性上比 SFT 提升 15%~25%,比 PPO 提升 8%~15%(因 GRPO 的相对优势更能区分 “细微法律差异”,如相似条文的准确引用);
  • GRPO 的策略熵高于 PPO,说明其探索能力更强,能覆盖更多法律推理场景;
  • 奖励分数均值在 GRPO 训练过程中持续上升,且最终稳定在 0.7~0.8(SFT 约 0.5,PPO 约 0.6)。

七、复现难点与解决方案

  1. 泰国法律数据稀缺:
    • 解决方案:用 “数据增强” 生成更多样本(如对法律条文进行同义改写、对问题进行 paraphrase),或复用多语言法律数据(如英文法律 QA 数据翻译为泰语,再人工修正)。
  2. GRPO 算法实现细节:
    • 解决方案:若论文未开源,可基于verl(字节跳动 RLHF 框架,支持 GRPO)或trl(Hugging Face RL 库)的 GRPO 接口,调整参数适配泰国法律 QA;若需自定义,参考 GRPO 论文核心公式(组相对优势计算、损失函数)。
  3. 泰语模型显存限制:
    • 解决方案:采用模型并行(如device_map="auto")、混合精度训练(FP16/FP8)、梯度累积,或选择小参数量泰语模型(如 WangchanBERTa-base 而非 large)。

八、复现 checklist(确保完整性)

  1.  确认泰国法律 QA 数据集的格式与论文一致(含问题、法律条文、候选回答);
  2.  完成 SFT 基线模型训练,评估指标达标;
  3.  训练奖励模型,偏好准确率≥85%;
  4.  配置 GRPO 参数(参考论文或上述建议值),启动训练;
  5.  完成对比实验(SFT、PPO、GRPO),记录所有指标;
  6.  验证 GRPO 在法律准确性上的提升是否符合论文结论;
  7.  排查训练中的异常(如 NaN 损失、奖励不上升),调整超参数。

 

通过以上步骤,可系统复现论文核心实验,重点验证 GRPO 在 “泰国法律 QA 任务中对推理准确性、条文引用正确性” 的提升效果。
论文:https://modelscope.cn/papers/2507.09638

image

image

 

posted on 2025-09-10 11:20  limingqi  阅读(1)  评论(0)    收藏  举报

导航