VERL-GRPO实现

VERL-GRPO 源码分析

脚本为verl v0.5.0中的快速开始脚本

# Tested successfully on the hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 image.
# It outperforms the Qwen2 7B base model by two percentage points on the test set of GSM8K.

set -x

data_dir="/home/hadoop-aipnlp/dolphinfs_ssd_hadoop-aipnlp/EVA/liruihan05/projects/verl/data/gsm8k"
model_path="/home/hadoop-aipnlp/dolphinfs_ssd_hadoop-aipnlp/EVA/liruihan05/models/open-sources/Qwen/Qwen3-8B"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=${data_dir}/train.parquet \
    data.val_files=${data_dir}/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen3-8B \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console"]' \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen3_8b_function_rm' \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

RayPPOTrainer.fit()

    def fit(self):
        """
        The training loop of PPO.
        The driver process only need to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        breakpoint()  # ✅ 会执行:调试断点
        from omegaconf import OmegaConf  # ✅ 会执行:导入配置管理库

        from verl.utils.tracking import Tracking  # ✅ 会执行:导入追踪工具类

        # ✅ 会执行:初始化日志追踪器
        # logger: Tracking对象,用于记录训练指标和配置
        logger = Tracking(
            project_name=self.config.trainer.project_name,  # 'verl_grpo_example_gsm8k'
            experiment_name=self.config.trainer.experiment_name,  # 'qwen3_8b_function_rm'
            default_backend=self.config.trainer.logger,  # ['console']
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        # ✅ 会执行:初始化全局步数计数器
        self.global_steps = 0

        # ✅ 会执行:在开始训练前加载检查点
        self._load_checkpoint()

        # ❓ 可能执行:取决于val_reward_fn是否配置
        # 脚本中配置了data.val_files,所以可能有验证函数
        # self.val_reward_fn: 可能为None或函数对象
        # self.config.trainer.get("val_before_train", True): 默认为True
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            # ❓ 可能执行:如果val_reward_fn不为None
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            # ❌ 不会执行:脚本中没有配置val_only
            if self.config.trainer.get("val_only", False):
                return

        # ✅ 会执行:添加进度条
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # ✅ 会执行:从步数1开始
        self.global_steps += 1
        last_val_metrics = None
        self.max_steps_duration = 0

        # ✅ 会执行:外层循环,总共15个epoch(trainer.total_epochs=15)
        for epoch in range(self.config.trainer.total_epochs):
            # ✅ 会执行:内层循环,遍历训练数据加载器
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}

                # ❌ 不会执行:脚本中没有配置profile_steps,所以do_profile=False
                do_profile = (
                    self.global_steps in self.config.trainer.profile_steps
                    if self.config.trainer.profile_steps is not None
                    else False
                )
                # ✅ 会执行:但do_profile=False,所以不会真正进行性能分析
                with marked_timer("start_profile", timing_raw):
                    self._start_profiling(do_profile)

                # ✅ 会执行:将batch字典转换为DataProto对象
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # ✅ 会执行:准备用于生成时需要弹出的键列表
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                # ❓ 可能执行:取决于batch中是否包含这些键
                if "multi_modal_data" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                if "interaction_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
                if "index" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("index")
                if "agent_name" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("agent_name")

                # ✅ 会执行:从batch中弹出生成所需的键
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )

                # ✅ 会执行:传递全局步数到trace
                gen_batch.meta_info["global_steps"] = self.global_steps
                # ✅ 会执行:重复batch,n=5(actor_rollout_ref.rollout.n=5)
                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)

                # ✅ 会执行:判断是否为最后一步
                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer("step", timing_raw):
                    # ✅ 会执行:生成阶段
                    with marked_timer("gen", timing_raw, color="red"):
                        # ✅ 会执行:脚本中没有配置async_rollout_mode,默认为False
                        if not self.async_rollout_mode:
                            # ✅ 会执行:同步模式生成
                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        else:
                            # ❌ 不会执行:async_rollout_mode=False
                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    # ❌ 不会执行:adv_estimator=grpo,不是REMAX
                    # 脚本配置:algorithm.adv_estimator=grpo
                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        # ❌ 不会执行:REMAX分支
                        with marked_timer("gen_max", timing_raw, color="purple"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            if not self.async_rollout_mode:
                                gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
                            else:
                                gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output

                    # ✅ 会执行:为每个样本生成唯一ID
                    batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                    )
                    # ✅ 会执行:重复batch以对齐rollout中重复的响应,n=5
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    # ✅ 会执行:将生成的响应合并到batch中
                    batch = batch.union(gen_batch_output)

                    # ❓ 可能执行:如果batch中没有response_mask则计算
                    if "response_mask" not in batch.batch.keys():
                        # ❓ 可能执行:取决于是否已有response_mask
                        batch.batch["response_mask"] = compute_response_mask(batch)
                    # ❌ 不会执行:脚本中没有配置balance_batch,默认为False
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # ✅ 会执行:计算全局有效token数量
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    with marked_timer("reward", timing_raw, color="yellow"):
                        # ❌ 不会执行:GRPO不使用奖励模型(use_rm=False,因为Role.RewardModel不在role_worker_mapping中)
                        # 脚本中没有配置reward_model相关参数
                        if self.use_rm:
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        # ❌ 不会执行:脚本中没有配置launch_reward_fn_async,默认为False
                        if self.config.reward_model.launch_reward_fn_async:
                            future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
                        else:
                            # ✅ 会执行:同步计算奖励
                            reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

                    # ✅ 会执行:重新计算旧的对数概率
                    with marked_timer("old_log_prob", timing_raw, color="blue"):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                        # ❓ 可能执行:取决于rollout时是否保存了log_probs
                        # 如果rollout阶段保存了log_probs,则会执行此分支用于诊断
                        if "rollout_log_probs" in batch.batch.keys():
                            # ❓ 可能执行:如果存在rollout_log_probs
                            rollout_old_log_probs = batch.batch["rollout_log_probs"]
                            actor_old_log_probs = batch.batch["old_log_probs"]
                            attention_mask = batch.batch["attention_mask"]
                            responses = batch.batch["responses"]
                            response_length = responses.size(1)
                            response_mask = attention_mask[:, -response_length:]
                            rollout_probs = torch.exp(rollout_old_log_probs)
                            actor_probs = torch.exp(actor_old_log_probs)
                            rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
                            rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
                            rollout_probs_diff_max = torch.max(rollout_probs_diff)
                            rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
                            rollout_probs_diff_std = torch.std(rollout_probs_diff)
                            metrics.update(
                                {
                                    "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
                                    "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
                                    "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
                                }
                            )

                    # ❓ 可能执行:取决于是否使用参考策略
                    # 脚本中配置了actor_rollout_ref.ref相关参数,所以很可能使用参考策略
                    # 如果use_reference_policy=True(Role.RefPolicy在role_worker_mapping中)
                    if self.use_reference_policy:
                        # ❓ 可能执行:如果使用参考策略
                        with marked_timer("ref", timing_raw, color="olive"):
                            # ❓ 可能执行:取决于ref_in_actor的值
                            # ref_in_actor = (lora_rank > 0),脚本中没有配置lora,所以ref_in_actor=False
                            if not self.ref_in_actor:
                                # ✅ 如果使用参考策略,会执行此分支(因为ref_in_actor=False)
                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            else:
                                # ❌ 不会执行:ref_in_actor=False
                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # ❌ 不会执行:GRPO不使用critic(use_critic=False)
                    # 脚本配置:algorithm.adv_estimator=grpo
                    # 根据代码372-384行,GRPO在列表中,所以use_critic=False
                    if self.use_critic:
                        with marked_timer("values", timing_raw, color="cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with marked_timer("adv", timing_raw, color="brown"):
                        # ✅ 会执行:reward_extra_infos_dict类型注解
                        reward_extra_infos_dict: dict[str, list]
                        # ❌ 不会执行:launch_reward_fn_async=False
                        if self.config.reward_model.launch_reward_fn_async:
                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                        # ✅ 会执行:将token级别的分数添加到batch中
                        batch.batch["token_level_scores"] = reward_tensor

                        # ❓ 可能执行:如果reward_extra_infos_dict不为空
                        if reward_extra_infos_dict:
                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # ❌ 不会执行:脚本配置:algorithm.use_kl_in_reward=False
                        if self.config.algorithm.use_kl_in_reward:
                            batch, kl_metrics = apply_kl_penalty(
                                batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                            )
                            metrics.update(kl_metrics)
                        else:
                            # ✅ 会执行:直接将分数作为奖励(因为use_kl_in_reward=False)
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                        # ✅ 会执行:计算优势函数
                        # norm_adv_by_std_in_grpo: 默认为True(GRPO优势归一化因子)
                        norm_adv_by_std_in_grpo = self.config.algorithm.get(
                            "norm_adv_by_std_in_grpo", True
                        )

                        # ✅ 会执行:使用GRPO计算优势(调用compute_grpo_outcome_advantage函数)
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,  # GRPO
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,  # 5
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                            config=self.config.algorithm,
                        )

                    # ❌ 不会执行:GRPO不使用critic(use_critic=False)
                    if self.use_critic:
                        with marked_timer("update_critic", timing_raw, color="pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # ✅ 会执行:critic_warmup=0,所以条件总是满足(0 <= global_steps)
                    # 脚本配置:trainer.critic_warmup=0
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # ✅ 会执行:更新actor
                        with marked_timer("update_actor", timing_raw, color="red"):
                            batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                            # 核心函数是(update_policy)
                            actor_output = self.actor_rollout_wg.update_actor(batch) 
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # ❌ 不会执行:脚本中没有配置rollout_data_dir,默认为None
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        with marked_timer("dump_rollout_generations", timing_raw, color="green"):
                            inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                            outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                            scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
                            if "request_id" in batch.non_tensor_batch:
                                reward_extra_infos_dict.setdefault(
                                    "request_id",
                                    batch.non_tensor_batch["request_id"].tolist(),
                                )
                            self._dump_generations(
                                inputs=inputs,
                                outputs=outputs,
                                scores=scores,
                                reward_extra_infos_dict=reward_extra_infos_dict,
                                dump_path=rollout_data_dir,
                            )

                    # ❓ 可能执行:取决于val_reward_fn是否配置,以及步数
                    # 脚本配置:trainer.test_freq=5,所以每5步会执行一次验证
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.test_freq > 0  # 5 > 0,满足
                        and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)  # 每5步或最后一步
                    ):
                        # ❓ 可能执行:如果满足条件(每5步或最后一步,且val_reward_fn不为None)
                        with marked_timer("testing", timing_raw, color="green"):
                            val_metrics: dict = self._validate()
                            # ❓ 可能执行:如果是最后一步
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    # ✅ 会执行:检查ESI是否即将到期
                    esi_close_to_expiration = should_save_ckpt_esi(
                        max_steps_duration=self.max_steps_duration,
                        redundant_time=self.config.trainer.esi_redundant_time,
                    )
                    # ✅ 会执行:检查保存检查点的条件
                    # 脚本配置:trainer.save_freq=20,所以每20步会保存一次
                    if self.config.trainer.save_freq > 0 and (  # 20 > 0,满足
                        is_last_step
                        or self.global_steps % self.config.trainer.save_freq == 0  # 每20步
                        or esi_close_to_expiration
                    ):
                        # ❓ 可能执行:如果ESI即将到期
                        if esi_close_to_expiration:
                            print("Force saving checkpoint: ESI instance expiration approaching.")
                        # ✅ 会执行:如果满足保存条件(每20步或最后一步或ESI到期)
                        with marked_timer("save_checkpoint", timing_raw, color="green"):
                            self._save_checkpoint()

                # ✅ 会执行:停止性能分析
                with marked_timer("stop_profile", timing_raw):
                    self._stop_profiling(do_profile)

                # ✅ 会执行:获取当前步骤的持续时间
                steps_duration = timing_raw["step"]
                self.max_steps_duration = max(self.max_steps_duration, steps_duration)

                # ✅ 会执行:更新训练指标
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )
                # ✅ 会执行:收集数据指标(use_critic=False)
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                # ✅ 会执行:收集计时指标
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # ✅ 会执行:获取GPU数量
                n_gpus = self.resource_pool_manager.get_n_gpus()
                # ✅ 会执行:收集吞吐量指标
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                # ❓ 可能执行:取决于采样器类型
                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
                    # ❓ 可能执行:如果使用课程学习采样器
                    self.train_dataloader.sampler.update(batch=batch)

                # ✅ 会执行:记录指标到日志系统
                logger.log(data=metrics, step=self.global_steps)

                # ✅ 会执行:更新进度条
                progress_bar.update(1)
                # ✅ 会执行:增加全局步数
                self.global_steps += 1

                # ❓ 可能执行:如果是最后一步
                if is_last_step:
                    # ❓ 可能执行:如果是最后一步
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                # ❓ 可能执行:取决于数据集是否有on_batch_end方法
                if hasattr(self.train_dataset, "on_batch_end"):
                    # ❓ 可能执行:如果数据集有on_batch_end方法
                    self.train_dataset.on_batch_end(batch=batch)

优势计算

compute_grpo_outcome_advantage:

def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length)
        index: `(np.ndarray)`
            index array for grouping
        epsilon: `(float)`
            small value to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)`
            whether to scale the GRPO advantage
        config: `(Optional[AlgoConfig])`
            algorithm configuration object

    Note:
        If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
        If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `(torch.Tensor)`
            shape is (bs, response_length)
        Returns: `(torch.Tensor)`
            shape is (bs, response_length)
    """
    # 步骤1: 计算每个样本的总奖励(对token维度求和)
    # scores: torch.Tensor,形状从 (bs, response_length) 变为 (bs,)
    # 将每个响应的所有token奖励求和,得到每个响应的总奖励分数
    scores = token_level_rewards.sum(dim=-1)

    # 步骤2: 初始化用于存储分组信息的字典
    # id2score: defaultdict(list),键为index(uid),值为该组所有样本的分数列表
    # id2mean: dict,键为index,值为该组分数的均值
    # id2std: dict,键为index,值为该组分数的标准差
    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    # 步骤3: 在no_grad上下文中计算(不需要梯度)
    with torch.no_grad():
        # 获取batch大小
        # bsz: int类型,batch的大小
        bsz = scores.shape[0]
        
        # 步骤4: 按index(uid)分组收集分数
        # 遍历batch中的每个样本,将分数按index分组
        # index[i]: 第i个样本的uid(唯一标识符)
        # scores[i]: 第i个样本的总奖励分数
        # 相同uid的样本属于同一组(同一个prompt的多个响应)
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        
        # 步骤5: 计算每个组的均值和标准差
        # 遍历每个组(每个唯一的index/uid)
        for idx in id2score:
            # 如果组内只有一个样本(只有一个响应)
            if len(id2score[idx]) == 1:
                # 均值设为0(因为只有一个样本,无法计算相对优势)
                # id2mean[idx]: torch.Tensor,标量,值为0.0
                id2mean[idx] = torch.tensor(0.0)
                # 标准差设为1(避免除零,同时保持数值稳定)
                # id2std[idx]: torch.Tensor,标量,值为1.0
                id2std[idx] = torch.tensor(1.0)
            # 如果组内有多个样本(多个响应)
            elif len(id2score[idx]) > 1:
                # 计算组内分数的均值
                # id2mean[idx]: torch.Tensor,标量,该组所有响应分数的平均值
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                # 计算组内分数的标准差
                # id2std[idx]: torch.Tensor,标量,该组所有响应分数的标准差
                # 注意:这里使用torch.tensor([id2score[idx]]),将列表包装成tensor再计算std
                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
            else:
                # 异常情况:组内没有样本(不应该发生)
                raise ValueError(f"no score in prompt index: {idx}")
        
        # 步骤6: 计算归一化的优势分数
        # 遍历batch中的每个样本
        for i in range(bsz):
            # 如果启用标准差归一化(原始GRPO方法)
            if norm_adv_by_std_in_grpo:
                # 计算标准化优势:(分数 - 组内均值) / (组内标准差 + epsilon)
                # scores[i]: torch.Tensor,标量,更新后的优势分数
                # id2mean[index[i]]: 该样本所在组的均值
                # id2std[index[i]]: 该样本所在组的标准差
                # epsilon: 防止除零的小值(1e-6)
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                # 如果不归一化(Dr.GRPO方法):只减去均值
                # scores[i]: torch.Tensor,标量,更新后的优势分数(相对于组内均值)
                scores[i] = scores[i] - id2mean[index[i]]
        
        # 步骤7: 将标量优势扩展到token级别
        # scores.unsqueeze(-1): 从 (bs,) 扩展为 (bs, 1)
        # response_mask: (bs, response_length),标识哪些位置是响应部分
        # scores: torch.Tensor,形状从 (bs,) 变为 (bs, response_length)
        # 每个token位置都使用相同的优势值(因为GRPO是基于outcome reward的)
        scores = scores.unsqueeze(-1) * response_mask

    # 步骤8: 返回优势和回报
    # 在GRPO中,returns等于advantages(因为是基于outcome reward,没有时间折扣)
    # scores: torch.Tensor,形状 (bs, response_length),既是advantages也是returns
    return scores, scores

损失计算

update_policy:(loss=compute_policy_loss +0.001*KL loss)

def update_policy(self, data: DataProto):
    # 步骤1: 设置为训练模式
    self.actor_module.train()
    
    # 步骤2: 提取temperature
    temperature = data.meta_info["temperature"]
    
    # 步骤3: 选择需要的键(对于GRPO)
    select_keys = [
        "responses",          # 生成的响应,形状: (bs, response_length)
        "response_mask",      # 响应掩码,形状: (bs, response_length)
        "input_ids",          # 输入token IDs,形状: (bs, seq_len)
        "attention_mask",     # 注意力掩码,形状: (bs, seq_len)
        "position_ids",       # 位置IDs,形状: (bs, seq_len)
        "old_log_probs",      # 旧的对数概率,形状: (bs, response_length)
        "advantages",         # 优势(GRPO计算),形状: (bs, response_length)
    ]
    # 脚本配置: actor_rollout_ref.actor.use_kl_loss=True
    if self.config.use_kl_loss:
        select_keys.append("ref_log_prob")  # 参考策略的对数概率,形状: (bs, response_length)
    
    # 步骤4: 选择数据
    data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)
    
    # 步骤5: 分割成mini-batches
    # 脚本配置: actor_rollout_ref.actor.ppo_mini_batch_size=256
    mini_batches = data.split(self.config.ppo_mini_batch_size)
    
    metrics = {}
    # 步骤6: 多个PPO epoch(通常为1-4个epoch)
    for _ in range(self.config.ppo_epochs):
        for batch_idx, mini_batch in enumerate(mini_batches):
            # 步骤7: 进一步分割成micro-batches(用于梯度累积)
            # 脚本配置: actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32
            if self.config.use_dynamic_bsz:
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
            else:
                self.gradient_accumulation = (
                    self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                )  # 256 // 32 = 8
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
            
            # 步骤8: 清零梯度
            self.actor_optimizer.zero_grad()
            
            # 步骤9: 遍历每个micro-batch
            for micro_batch in micro_batches:
                micro_batch_metrics = {}
                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                response_mask = model_inputs["response_mask"]      # (bs, response_length)
                old_log_prob = model_inputs["old_log_probs"]        # (bs, response_length)
                advantages = model_inputs["advantages"]              # (bs, response_length) - GRPO计算的
                
                # 步骤10: 获取配置参数
                clip_ratio = self.config.clip_ratio                 # PPO裁剪比例(通常0.2)
                entropy_coeff = self.config.entropy_coeff           # 熵系数(脚本中=0)
                loss_agg_mode = self.config.loss_agg_mode           # 损失聚合模式
                
                # 步骤11: 前向传播,计算新的log_prob和entropy
                # 脚本配置: actor_rollout_ref.actor.entropy_coeff=0,所以不计算entropy
                calculate_entropy = False
                if entropy_coeff != 0:
                    calculate_entropy = True
                entropy, log_prob = self._forward_micro_batch(
                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
                )
                # log_prob: (bs, response_length) - 当前策略的对数概率
                
                # 步骤12: 计算策略损失(PPO clipped loss)
                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                    old_log_prob=old_log_prob,      # (bs, response_length)
                    log_prob=log_prob,              # (bs, response_length)
                    advantages=advantages,          # (bs, response_length) - GRPO优势
                    response_mask=response_mask,    # (bs, response_length)
                    cliprange=clip_ratio,           # 0.2
                    loss_agg_mode=loss_agg_mode,
                )
                # pg_loss: 标量,策略梯度损失
                # pg_clipfrac: 被裁剪的比例
                # ppo_kl: KL散度(用于监控)
                
                # 步骤13: 计算总损失
                # 脚本配置: entropy_coeff=0,所以不添加熵损失
                if entropy_coeff != 0:
                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                    policy_loss = pg_loss - entropy_loss * entropy_coeff
                else:
                    policy_loss = pg_loss
                
                # 步骤14: 添加KL损失(如果启用)
                # 脚本配置: actor_rollout_ref.actor.use_kl_loss=True, kl_loss_coef=0.001
                if self.config.use_kl_loss:
                    ref_log_prob = model_inputs["ref_log_prob"]  # (bs, response_length)
                    # 计算KL散度惩罚
                    kld = kl_penalty(
                        logprob=log_prob,           # 当前策略
                        ref_logprob=ref_log_prob,   # 参考策略
                        kl_penalty=self.config.kl_loss_type  # "low_var_kl"
                    )
                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef  # 0.001
                    micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
                    micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
                
                # 步骤15: 缩放损失(用于梯度累积)
                if self.config.use_dynamic_bsz:
                    loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
                else:
                    loss = policy_loss / self.gradient_accumulation  # 除以梯度累积步数
                
                # 步骤16: 反向传播
                loss.backward()
                
                # 步骤17: 记录指标
                micro_batch_metrics.update({
                    "actor/pg_loss": pg_loss.detach().item(),
                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                    "actor/ppo_kl": ppo_kl.detach().item(),
                    "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                })
                append_to_dict(metrics, micro_batch_metrics)
            
            # 步骤18: 优化器步骤(更新参数)
            grad_norm = self._optimizer_step()
            mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
            append_to_dict(metrics, mini_batch_metrics)
    
    self.actor_optimizer.zero_grad()
    return metrics

compute_policy_loss 函数详解:重要性采样、Clipped PPO 和 Dual-clip PPO

重要性采样

  1. 起点:我们想要计算什么?
    重要性采样要解决的核心问题是:我们想计算一个函数 \(f(x)\) 在目标分布 \(p(x)\) 下的期望值,但我们只有从另一个分布 \(q(x)\) 中采样的样本。
    用数学语言表达,我们的目标是计算:

\[\mathbb{E}_{x \sim p}[f(x)] = \int p(x) f(x) dx \]

但我们无法直接从 \(p(x)\) 中采样,我们只有从 \(q(x)\) 中采样的样本 \(\{x_1, x_2, \ldots, x_n\} \sim q(x)\)
2. 关键的数学“技巧”:乘以一个“1”
推导的巧妙之处在于一个简单的代数变换。我们在积分中乘以再除以 \(q(x)\),这相当于乘以1,不会改变积分的值:

\[\mathbb{E}_{x \sim p}[f(x)] = \int p(x) f(x) dx = \int \frac{p(x)}{q(x)} \cdot q(x) \cdot f(x) dx\]

为什么这么做?因为现在我们有了 \(q(x) dx\) 这一项,这正好是分布 \(q\) 的概率测度。
3. 重新解读积分:在新的分布下求期望
现在,观察上式的最右边:

\[\int\frac{p(x)}{q(x)}\cdot f(x)\cdot q(x)dx \]

这可以被看作是函数\(\frac{p(x)}{q(x)}f(x)\)在分布\(q(x)\)下的期望值:

\[\mathbb{E}_{x\sim p}[f(x)]=\mathbb{E}_{x\sim q}\left[\frac{p(x)}{q(x)}f(x)\right] \]

这就是重要性采样的核心公式
4. 从理论到实践:蒙特卡洛估计
理论上,我们有上面的精确等式。但在实践中,我们无法计算真实的期望(因为它需要无穷多的样本),那么根据大数定律,我们从 \(q(x)\) 中采集 \(n\) 个样本 \(x_i\),然后用样本的平均值来近似期望:

\[\mathbb{E}_{x \sim p}[f(x)] = \mathbb{E}_{x \sim q}\left[\frac{p(x)}{q(x)}f(x)\right] \approx \frac{1}{n} \sum_{i=1}^{n} \frac{p(x_i)}{q(x_i)}f(x_i) \]

其中,\(\frac{p(x_i)}{q(x_i)}\) 被称为 重要性权重(Importance Weight)。

  1. 直观理解重要性权重
    重要性权重 \(\frac{p(x_i)}{q(x_i)}\) 起到了一个校正作用:
  • 如果 \(p(x_i) > q(x_i)\):权重 > 1。这意味着在目标分布 \(p\) 中,\(x_i\) 出现的概率比在采样分布 \(q\) 中更高。因为我们幸运地从 \(q\) 中采到了这个在 \(p\) 中更常见的点,所以我们给它更大的权重,以补偿它在 \(q\) 中出现的低概率。
  • 如果 \(p(x_i) < q(x_i)\):权重 < 1。这意味着 \(x_i\)\(p\) 中其实不那么常见,但在 \(q\) 中却很常见。我们采到了很多这样的点,但它们对目标分布 \(p\) 的代表性不强,所以需要降低权重。
  1. 重要性采样的方差(和重要性权重成平方关系)
    加入设重要性权重为 \(\frac{p(x_i)}{q(x_i)}=w(x)\)
    现在,我们来计算这个重要性采样估计量的方差。根据方差公式 \(\text{Var}(Y) = \mathbb{E}[Y^2] - (\mathbb{E}[Y])^2\),我们可以计算估计量 \(w(x)f(x)\) 的方差:
    为了看清核心关系,我们假设 \(f(x)\) 是一个常数(比如为1),来单独看权重 \(w(x)\) 带来的影响。此时估计量就是 \(w(x)\) 本身,其期望是1(\(\mathbb{E}_{x\sim q}[w(x)]=\sum_x q(x)\times \frac{p(x)}{q(x)} = \sum_x p(x)=1\))。
    那么它的方差为:

\[\text{Var}_{x\sim q}(w(x)) = \mathbb{E}_{x\sim q}[w(x)^2] - (\mathbb{E}_{x\sim q}[w(x)])^2\]

\[= \mathbb{E}_{x\sim q}[w(x)^2] - 1 \]

看,方差就等于权重的二阶矩 \(\mathbb{E}[w^2]\) 减去1。这就是“方差与权重 \(w\) 二阶矩直接相关”的最简洁体现。
所以,如果\(p\)\(q\)的分布相差太远,那么在某处,\(w(x)\)可能会非常大,另外一处,\(w(x)\)又会非常小(因为要保证期望为1),而大的\(w(x)\)会导致\(\mathbb{E}_{x\sim q}[w(x)^2]\)很大,进而会导致重要性采样的高方差.

大数定律

定理陈述(强大数定律 - 柯尔莫哥洛夫版)
\(X_{1}, X_{2}, X_{3}, \cdots\) 是独立同分布的随机变量序列,且具有有限的期望值 \(\mu = \mathbb{E}[X_{1}]\) 和有限的方差 \(\operatorname{Var}(X_{1}) = \sigma^{2} < \infty\)
那么,样本均值几乎必然收敛于期望值 \(\mu\)

\[P\left(\lim _{n \rightarrow \infty} \frac{1}{n} \sum_{i=1}^{n} X_{i} = \mu\right) = 1 \]

记作:\(\frac{1}{n} \sum_{i=1}^{n} X_{i} \xrightarrow{a.s.} \mu\)

背景问题

在策略梯度中,我们需要计算:

使用旧策略 收集的数据来更新新策略 时,需要重要性采样。

重要性采样原理

重要性采样允许用旧策略的样本估计新策略的期望:

重要性权重为:

代码实现

# 步骤1: 计算重要性采样比率
negative_approx_kl = log_prob - old_log_prob  # log(π_new/π_old)
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)  # ratio = π_new/π_old

数学表达式:

为什么需要重要性采样?

  1. 数据复用:用旧策略的数据更新新策略

  2. 样本效率:减少重新采样

  3. 在线学习:在策略更新过程中利用历史数据

问题:重要性采样可能不稳定

当 与 差异较大时, 可能很大,导致:

  • 高方差

  • 训练不稳定

  • 策略更新过大

2. Clipped PPO(标准 PPO)

compute_policy_loss 函数详解:重要性采样、Clipped PPO 和 Dual-clip PPO

1. 重要性采样(Importance Sampling)

背景问题

在策略梯度中,我们需要计算:

\[\nabla_\theta J(\theta) = \mathbb{E}_{a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot Q(s,a) \right] \]

使用旧策略 \(\pi_{\text{old}}\) 收集的数据来更新新策略 \(\pi_{\text{new}}\) 时,需要重要性采样。

重要性采样原理

重要性采样允许用旧策略的样本估计新策略的期望:

\[\mathbb{E}_{a \sim \pi_{\text{new}}} [f(a)] = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \frac{\pi_{\text{new}}(a|s)}{\pi_{\text{old}}(a|s)} \cdot f(a) \right] \]

重要性权重为:

\[w(a|s) = \frac{\pi_{\text{new}}(a|s)}{\pi_{\text{old}}(a|s)} = \exp(\log \pi_{\text{new}}(a|s) - \log \pi_{\text{old}}(a|s)) \]

代码实现

# 步骤1: 计算重要性采样比率
negative_approx_kl = log_prob - old_log_prob  # log(π_new/π_old)
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl)  # ratio = π_new/π_old

数学表达式:

\[\text{ratio} = \frac{\pi_{\text{new}}(a|s)}{\pi_{\text{old}}(a|s)} = \exp(\log \pi_{\text{new}}(a|s) - \log \pi_{\text{old}}(a|s)) \]

为什么需要重要性采样?

  1. 数据复用:用旧策略的数据更新新策略
  2. 样本效率:减少重新采样
  3. 在线学习:在策略更新过程中利用历史数据

问题:重要性采样可能不稳定

\(\pi_{\text{new}}\)\(\pi_{\text{old}}\) 差异较大时,\(w(a|s)\) 可能很大,导致:

  • 高方差
  • 训练不稳定
  • 策略更新过大

2. Clipped PPO(标准 PPO)

核心思想

限制重要性权重,防止策略更新过大。

标准策略梯度损失

\[L^{\text{PG}}(\theta) = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \frac{\pi_{\text{new}}(a|s)}{\pi_{\text{old}}(a|s)} \cdot A(s,a) \right] = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \text{ratio} \cdot A(s,a) \right] \]

PPO Clipped 损失

\[L^{\text{CLIP}}(\theta) = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \min\left( \text{ratio} \cdot A(s,a), \text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) \cdot A(s,a) \right) \right] \]

其中 \(\text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon)\) 将 ratio 限制在 \([1-\epsilon, 1+\epsilon]\) 范围内。

为什么取 min?

  • \(A(s,a) > 0\)(好动作)时:
    • 如果 \(\text{ratio} > 1+\epsilon\),使用 \((1+\epsilon) \cdot A\),限制过度增加
    • 如果 \(\text{ratio} < 1+\epsilon\),使用 \(\text{ratio} \cdot A\),允许正常更新
  • \(A(s,a) < 0\)(坏动作)时:
    • 如果 \(\text{ratio} < 1-\epsilon\),使用 \((1-\epsilon) \cdot A\),限制过度减少
    • 如果 \(\text{ratio} > 1-\epsilon\),使用 \(\text{ratio} \cdot A\),允许正常更新

等价形式(最大化目标,取负号后最小化):

\[L^{\text{CLIP}}(\theta) = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \max\left( -\text{ratio} \cdot A(s,a), -\text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) \cdot A(s,a) \right) \right] \]

代码实现

# 步骤2: 计算标准PPO clipped loss
pg_losses1 = -advantages * ratio  # 未裁剪的损失: -ratio * A

pg_losses2 = -advantages * torch.clamp(
    ratio, 1 - cliprange_low, 1 + cliprange_high  # clip(ratio, 1-ε, 1+ε)
)  # 裁剪后的损失: -clip(ratio, 1-ε, 1+ε) * A

clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
# max(-ratio * A, -clip(ratio, 1-ε, 1+ε) * A)

数学表达式:

\[L_1 = -\text{ratio} \cdot A(s,a) \]

\[L_2 = -\text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) \cdot A(s,a) \]

\[L^{\text{CLIP}} = \max(L_1, L_2) \]

Clipped PPO 的优势

  1. 限制更新幅度:\(\epsilon\)(如 0.2)控制单步更新范围
  2. 降低方差:避免极端权重
  3. 提高稳定性:防止策略崩溃

3. Dual-clip PPO

背景问题

\(A(s,a) < 0\)\(\text{ratio}\) 很大时,标准 PPO 可能仍允许较大的负更新,导致策略过度偏离。

核心思想

在负优势时,额外限制损失的下界,防止策略过度更新。

Dual-clip 公式

\(A(s,a) < 0\) 时,添加下界限制:

\[L^{\text{DUAL}}(\theta) = \begin{cases} L^{\text{CLIP}}(\theta) & \text{if } A(s,a) \geq 0 \\ \min(L^{\text{CLIP}}(\theta), -c \cdot A(s,a)) & \text{if } A(s,a) < 0 \end{cases} \]

其中 \(c > 1\) 是下界常数(通常 \(c = 3.0\))。

为什么需要 Dual-clip?

\(A(s,a) < 0\)\(\text{ratio}\) 很大时:

  • 标准 PPO:\(L = -\text{ratio} \cdot A\) 可能很大(负优势 × 大 ratio = 大负损失)
  • Dual-clip:限制为 \(-c \cdot A\),防止过度惩罚

代码实现

# 步骤3: 计算dual-clip PPO损失
pg_losses3 = -advantages * clip_ratio_c  # -c * A (下界)

clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
# min(-c * A, L^CLIP)

# 步骤4: 根据advantages的符号选择
pg_losses = torch.where(
    advantages < 0, 
    clip_pg_losses2,      # 负优势:使用dual-clip
    clip_pg_losses1       # 正优势:使用标准PPO
)

数学表达式:

\[L_3 = -c \cdot A(s,a) \]

\[L^{\text{DUAL}} = \begin{cases} L^{\text{CLIP}} & \text{if } A(s,a) \geq 0 \\ \min(L^{\text{CLIP}}, L_3) & \text{if } A(s,a) < 0 \end{cases} \]

Dual-clip 的优势

  1. 防止过度惩罚:在负优势时限制下界
  2. 提高稳定性:避免策略在坏动作上过度更新
  3. 适合复杂任务:在奖励稀疏或噪声大时更稳健

4. 完整损失函数总结

三阶段损失计算

\[\text{ratio} = \frac{\pi_{\text{new}}(a|s)}{\pi_{\text{old}}(a|s)} = \exp(\log \pi_{\text{new}}(a|s) - \log \pi_{\text{old}}(a|s)) \]

\[L_1 = -\text{ratio} \cdot A(s,a) \]

\[L_2 = -\text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon) \cdot A(s,a) \]

\[L^{\text{CLIP}} = \max(L_1, L_2) \]

\[L_3 = -c \cdot A(s,a) \]

\[L^{\text{DUAL}} = \begin{cases} L^{\text{CLIP}} & \text{if } A(s,a) \geq 0 \\ \min(L^{\text{CLIP}}, L_3) & \text{if } A(s,a) < 0 \end{cases} \]

最终损失(聚合后):

\[L = \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ L^{\text{DUAL}} \right] \]

可视化理解

对于正优势(\(A > 0\)):

  • 如果 \(\text{ratio} < 1+\epsilon\):使用 \(L_1 = -\text{ratio} \cdot A\)
  • 如果 \(\text{ratio} > 1+\epsilon\):使用 \(L_2 = -(1+\epsilon) \cdot A\)(裁剪)

对于负优势(\(A < 0\)):

  • 如果 \(\text{ratio} > 1-\epsilon\):使用 \(L_1 = -\text{ratio} \cdot A\)
  • 如果 \(\text{ratio} < 1-\epsilon\):使用 \(L_2 = -(1-\epsilon) \cdot A\)(裁剪)
  • 如果 \(L^{\text{CLIP}} < -c \cdot A\):进一步限制为 \(-c \cdot A\)(dual-clip)

5. 为什么这样设计?

重要性采样

  • 允许用旧数据更新新策略
  • 提高样本效率

Clipped PPO

  • 限制更新幅度(\(\epsilon = 0.2\)
  • 降低方差,提高稳定性

Dual-clip PPO

  • 在负优势时额外限制下界(\(c = 3.0\)
  • 防止策略在坏动作上过度偏离

6. 对于 GRPO 的默认值

根据脚本配置:

参数 作用
cliprange 0.2 标准 PPO 裁剪范围 \(\epsilon\)
cliprange_low 0.2 下界裁剪范围
cliprange_high 0.2 上界裁剪范围
clip_ratio_c 3.0 Dual-clip 下界常数 \(c\)
loss_agg_mode "token-mean" 对所有有效 token 求平均

7. 完整代码流程

def compute_policy_loss(
    old_log_prob,      # log π_old(a|s)
    log_prob,          # log π_new(a|s)
    advantages,        # A(s,a)
    response_mask,     # 掩码
    cliprange=0.2,     # ε
    clip_ratio_c=3.0,  # c
    loss_agg_mode="token-mean"
):
    # 1. 重要性采样比率
    ratio = exp(log_prob - old_log_prob)  # π_new/π_old
    
    # 2. 标准PPO clipped loss
    L1 = -ratio * advantages
    L2 = -clip(ratio, 1-ε, 1+ε) * advantages
    L_CLIP = max(L1, L2)
    
    # 3. Dual-clip(仅当advantages < 0时)
    L3 = -c * advantages
    L_DUAL = where(advantages < 0, min(L_CLIP, L3), L_CLIP)
    
    # 4. 聚合为标量
    loss = mean(L_DUAL, mask=response_mask)
    
    return loss, clipfrac, kl, clipfrac_lower

总结

  • 重要性采样:用旧策略数据更新新策略
  • Clipped PPO:限制更新幅度(\(\epsilon = 0.2\)
  • Dual-clip PPO:在负优势时额外限制下界(\(c = 3.0\)

三者结合,在保证样本效率的同时,提高训练稳定性,适用于 GRPO 等强化学习算法。

kl_penalty 函数完整讲解(K1,K2,K3)

1. 函数概述

kl_penalty 用于计算当前策略 π_actor 与参考策略 π_ref 之间的 KL 散度,作为惩罚项加入损失,防止策略更新过大。

2. KL 散度的标准定义

对于两个概率分布 π_ref 和 π_actor,KL 散度定义为:

\[\text{KL}(\pi_{\text{ref}} \| \pi_{\text{actor}}) = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} \right] \]

展开后:

\[\text{KL}(\pi_{\text{ref}} \| \pi_{\text{actor}}) = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s) \right] \]

其中:

  • ref_logprob = log π_ref(a|s)
  • logprob = log π_actor(a|s)

3. 三种 KL 散度估计方法(K1、K2、K3)

K1 估计器(一阶近似,高方差)

公式:

\[\text{KL}_{\text{K1}} = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} \right] = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s) \right] \]

代码实现:

if kl_penalty in ("kl", "k1"):
    return logprob - ref_logprob
    # 注意:代码中返回的是 logprob - ref_logprob
    # 但标准KL应该是 ref_logprob - logprob
    # 这里可能是为了与损失函数中的符号一致(损失中会取负号)

特点:

  • 无偏估计
  • 方差较高,可能不稳定
  • 计算简单

K2 估计器(二阶近似,低方差)

公式:

\[\text{KL}_{\text{K2}} = \frac{1}{2} \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \left( \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} \right)^2 \right] = \frac{1}{2} \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \left( \log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s) \right)^2 \right] \]

代码实现:

if kl_penalty in ("mse", "k2"):
    return 0.5 * (logprob - ref_logprob).square()
    # = 0.5 * (logprob - ref_logprob)²

特点:

  • 有偏估计(近似)
  • 方差较低,更稳定
  • 使用平方项平滑估计

K3 估计器(低方差近似,推荐)

公式推导:

使用泰勒展开近似 KL 散度:

\[\text{KL}(\pi_{\text{ref}} \| \pi_{\text{actor}}) = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} \right] \]

定义比率:

\[r = \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} = \exp(\log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s)) \]

K3 估计器公式:

\[\text{KL}_{\text{K3}} = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} - 1 - \log \frac{\pi_{\text{ref}}(a|s)}{\pi_{\text{actor}}(a|s)} \right] \]

展开为:

\[\text{KL}_{\text{K3}} = \mathbb{E}_{a \sim \pi_{\text{ref}}} \left[ \exp(\log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s)) - (\log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s)) - 1 \right] \]

代码实现:

if kl_penalty in ("low_var_kl", "k3"):
    kl = ref_logprob - logprob                    # log(π_ref/π_actor)
    kl = torch.clamp(kl, min=-20, max=20)         # 数值稳定性裁剪
    ratio = torch.exp(kl)                          # ratio = π_ref/π_actor
    kld = (ratio - kl - 1)                        # K3公式
    return torch.clamp(kld, min=-10, max=10)      # 最终裁剪

数学表达式:

\[\text{KL}_{\text{K3}} = \exp(\text{ref\_logprob} - \text{logprob}) - (\text{ref\_logprob} - \text{logprob}) - 1 \]

其中:

  • \(\text{ref\_logprob} - \text{logprob} = \log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\)
  • \(\exp(\text{ref\_logprob} - \text{logprob}) = \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\)

特点:

4. 其他方法

"abs"(绝对值)

\[\text{KL}_{\text{abs}} = \left| \log \pi_{\text{ref}}(a|s) - \log \pi_{\text{actor}}(a|s) \right| \]

代码:

if kl_penalty == "abs":
    return (logprob - ref_logprob).abs()

5. 完整代码实现

def kl_penalty(
    logprob: torch.FloatTensor,      # 形状: (bs, response_length)
    ref_logprob: torch.FloatTensor,  # 形状: (bs, response_length)
    kl_penalty: str                  # "kl"/"k1", "mse"/"k2", "low_var_kl"/"k3", "abs"
) -> torch.FloatTensor:              # 返回: (bs, response_length)
    """
    计算KL散度惩罚项
    
    Args:
        logprob: 当前策略的对数概率 log π_actor(a|s)
        ref_logprob: 参考策略的对数概率 log π_ref(a|s)
        kl_penalty: KL散度计算类型
        
    Returns:
        KL散度估计值,形状与输入相同
    """
    
    # K1估计器:一阶近似
    if kl_penalty in ("kl", "k1"):
        # KL ≈ log π_actor - log π_ref
        # 注意:这里返回的是负的KL(因为损失函数中会取负号)
        return logprob - ref_logprob
    
    # 绝对值方法
    if kl_penalty == "abs":
        # KL ≈ |log π_actor - log π_ref|
        return (logprob - ref_logprob).abs()
    
    # K2估计器:二阶近似
    if kl_penalty in ("mse", "k2"):
        # KL ≈ 0.5 * (log π_actor - log π_ref)²
        return 0.5 * (logprob - ref_logprob).square()
    
    # K3估计器:低方差近似(推荐)
    if kl_penalty in ("low_var_kl", "k3"):
        # kl = log(π_ref/π_actor) = ref_logprob - logprob
        kl = ref_logprob - logprob
        
        # 数值稳定性裁剪:防止exp溢出
        kl = torch.clamp(kl, min=-20, max=20)
        
        # ratio = π_ref/π_actor = exp(log(π_ref/π_actor))
        ratio = torch.exp(kl)
        
        # K3公式:KL = ratio - log(ratio) - 1
        # 其中 log(ratio) = kl = ref_logprob - logprob
        kld = (ratio - kl - 1)
        
        # 最终裁剪:防止数值不稳定
        return torch.clamp(kld, min=-10, max=10)
    
    # 完整KL散度(需要完整logits,未实现)
    if kl_penalty == "full":
        raise NotImplementedError
    
    raise NotImplementedError(f"Unknown kl_penalty type: {kl_penalty}")

6. 在 PPO 训练中的使用

update_policy 中的使用流程:

# 1. 计算KL散度(每个token位置)
kld = kl_penalty(
    logprob=log_prob,                    # 当前策略的对数概率
    ref_logprob=ref_log_prob,            # 参考策略的对数概率
    kl_penalty="low_var_kl"              # 使用K3估计器
)  # 形状: (bs, response_length)

# 2. 聚合为标量损失
kl_loss = agg_loss(
    loss_mat=kld,                        # KL散度矩阵
    loss_mask=response_mask,              # 响应掩码
    loss_agg_mode="token-mean"           # 对所有有效token求平均
)  # 标量

# 3. 添加到总损失中
policy_loss = policy_loss + kl_loss * kl_loss_coef
#                                    = kl_loss * 0.001

7. 对于 run_qwen3-8b.sh 脚本

脚本配置:

actor_rollout_ref.actor.kl_loss_type=low_var_kl  # 使用K3估计器
actor_rollout_ref.actor.kl_loss_coef=0.001       # KL损失系数

因此使用 K3 估计器:

\[\text{KL}_{\text{K3}} = \exp(\text{ref\_logprob} - \text{logprob}) - (\text{ref\_logprob} - \text{logprob}) - 1 \]

8. 三种估计器的比较

估计器 公式 无偏性 方差 稳定性 推荐度
K1 \(\log \pi_{\text{ref}} - \log \pi_{\text{actor}}\) ✅ 无偏 ❌ 高 ❌ 不稳定 ⭐⭐
K2 \(\frac{1}{2}(\log \pi_{\text{ref}} - \log \pi_{\text{actor}})^2\) ❌ 有偏 ✅ 低 ✅ 稳定 ⭐⭐⭐
K3 \(\exp(\Delta) - \Delta - 1\),其中 \(\Delta = \log \pi_{\text{ref}} - \log \pi_{\text{actor}}\) ✅ 无偏 ✅ 低 ✅ 稳定 ⭐⭐⭐⭐⭐

9. 数学推导(K3 的推导)

K3 基于以下近似:

对于小的 \(x = \log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\),使用泰勒展开:

\[e^x = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \cdots \]

\[\log(1 + x) = x - \frac{x^2}{2} + \frac{x^3}{3} - \cdots \]

KL 散度可以写成:

\[\text{KL} = \mathbb{E}[x] = \mathbb{E}\left[\log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\right] \]

使用恒等式 \(x = e^x - 1 - (e^x - 1 - x)\),当 \(x\) 接近 0 时:

\[x \approx e^x - 1 - (e^x - 1 - x) = e^x - x - 1 \]

因此:

\[\text{KL} \approx \mathbb{E}[e^x - x - 1] = \mathbb{E}\left[\exp\left(\log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\right) - \log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}} - 1\right] \]

这就是 K3 估计器的公式。

总结

  • K1:简单但方差高
  • K2:稳定但有偏
  • K3:无偏且低方差,适合 PPO

脚本使用 K3(low_var_kl),系数为 0.001,有助于稳定训练并限制策略更新幅度。

posted @ 2025-12-16 10:56  Brain404  阅读(30)  评论(0)    收藏  举报