VERL-GRPO实现
VERL-GRPO 源码分析
脚本为verl v0.5.0中的快速开始脚本
# Tested successfully on the hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 image.
# It outperforms the Qwen2 7B base model by two percentage points on the test set of GSM8K.
set -x
data_dir="/home/hadoop-aipnlp/dolphinfs_ssd_hadoop-aipnlp/EVA/liruihan05/projects/verl/data/gsm8k"
model_path="/home/hadoop-aipnlp/dolphinfs_ssd_hadoop-aipnlp/EVA/liruihan05/models/open-sources/Qwen/Qwen3-8B"
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=${data_dir}/train.parquet \
data.val_files=${data_dir}/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=Qwen/Qwen3-8B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console"]' \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen3_8b_function_rm' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
RayPPOTrainer.fit()
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
breakpoint() # ✅ 会执行:调试断点
from omegaconf import OmegaConf # ✅ 会执行:导入配置管理库
from verl.utils.tracking import Tracking # ✅ 会执行:导入追踪工具类
# ✅ 会执行:初始化日志追踪器
# logger: Tracking对象,用于记录训练指标和配置
logger = Tracking(
project_name=self.config.trainer.project_name, # 'verl_grpo_example_gsm8k'
experiment_name=self.config.trainer.experiment_name, # 'qwen3_8b_function_rm'
default_backend=self.config.trainer.logger, # ['console']
config=OmegaConf.to_container(self.config, resolve=True),
)
# ✅ 会执行:初始化全局步数计数器
self.global_steps = 0
# ✅ 会执行:在开始训练前加载检查点
self._load_checkpoint()
# ❓ 可能执行:取决于val_reward_fn是否配置
# 脚本中配置了data.val_files,所以可能有验证函数
# self.val_reward_fn: 可能为None或函数对象
# self.config.trainer.get("val_before_train", True): 默认为True
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
# ❓ 可能执行:如果val_reward_fn不为None
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
# ❌ 不会执行:脚本中没有配置val_only
if self.config.trainer.get("val_only", False):
return
# ✅ 会执行:添加进度条
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# ✅ 会执行:从步数1开始
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0
# ✅ 会执行:外层循环,总共15个epoch(trainer.total_epochs=15)
for epoch in range(self.config.trainer.total_epochs):
# ✅ 会执行:内层循环,遍历训练数据加载器
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
# ❌ 不会执行:脚本中没有配置profile_steps,所以do_profile=False
do_profile = (
self.global_steps in self.config.trainer.profile_steps
if self.config.trainer.profile_steps is not None
else False
)
# ✅ 会执行:但do_profile=False,所以不会真正进行性能分析
with marked_timer("start_profile", timing_raw):
self._start_profiling(do_profile)
# ✅ 会执行:将batch字典转换为DataProto对象
batch: DataProto = DataProto.from_single_dict(batch_dict)
# ✅ 会执行:准备用于生成时需要弹出的键列表
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
# ❓ 可能执行:取决于batch中是否包含这些键
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")
if "index" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("index")
if "agent_name" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("agent_name")
# ✅ 会执行:从batch中弹出生成所需的键
gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
# ✅ 会执行:传递全局步数到trace
gen_batch.meta_info["global_steps"] = self.global_steps
# ✅ 会执行:重复batch,n=5(actor_rollout_ref.rollout.n=5)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
# ✅ 会执行:判断是否为最后一步
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# ✅ 会执行:生成阶段
with marked_timer("gen", timing_raw, color="red"):
# ✅ 会执行:脚本中没有配置async_rollout_mode,默认为False
if not self.async_rollout_mode:
# ✅ 会执行:同步模式生成
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
# ❌ 不会执行:async_rollout_mode=False
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
# ❌ 不会执行:adv_estimator=grpo,不是REMAX
# 脚本配置:algorithm.adv_estimator=grpo
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
# ❌ 不会执行:REMAX分支
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
if not self.async_rollout_mode:
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
else:
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
# ✅ 会执行:为每个样本生成唯一ID
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# ✅ 会执行:重复batch以对齐rollout中重复的响应,n=5
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
# ✅ 会执行:将生成的响应合并到batch中
batch = batch.union(gen_batch_output)
# ❓ 可能执行:如果batch中没有response_mask则计算
if "response_mask" not in batch.batch.keys():
# ❓ 可能执行:取决于是否已有response_mask
batch.batch["response_mask"] = compute_response_mask(batch)
# ❌ 不会执行:脚本中没有配置balance_batch,默认为False
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# ✅ 会执行:计算全局有效token数量
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
with marked_timer("reward", timing_raw, color="yellow"):
# ❌ 不会执行:GRPO不使用奖励模型(use_rm=False,因为Role.RewardModel不在role_worker_mapping中)
# 脚本中没有配置reward_model相关参数
if self.use_rm:
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# ❌ 不会执行:脚本中没有配置launch_reward_fn_async,默认为False
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
else:
# ✅ 会执行:同步计算奖励
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
# ✅ 会执行:重新计算旧的对数概率
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
# ❓ 可能执行:取决于rollout时是否保存了log_probs
# 如果rollout阶段保存了log_probs,则会执行此分支用于诊断
if "rollout_log_probs" in batch.batch.keys():
# ❓ 可能执行:如果存在rollout_log_probs
rollout_old_log_probs = batch.batch["rollout_log_probs"]
actor_old_log_probs = batch.batch["old_log_probs"]
attention_mask = batch.batch["attention_mask"]
responses = batch.batch["responses"]
response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]
rollout_probs = torch.exp(rollout_old_log_probs)
actor_probs = torch.exp(actor_old_log_probs)
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
rollout_probs_diff_max = torch.max(rollout_probs_diff)
rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
rollout_probs_diff_std = torch.std(rollout_probs_diff)
metrics.update(
{
"training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
"training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
"training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
}
)
# ❓ 可能执行:取决于是否使用参考策略
# 脚本中配置了actor_rollout_ref.ref相关参数,所以很可能使用参考策略
# 如果use_reference_policy=True(Role.RefPolicy在role_worker_mapping中)
if self.use_reference_policy:
# ❓ 可能执行:如果使用参考策略
with marked_timer("ref", timing_raw, color="olive"):
# ❓ 可能执行:取决于ref_in_actor的值
# ref_in_actor = (lora_rank > 0),脚本中没有配置lora,所以ref_in_actor=False
if not self.ref_in_actor:
# ✅ 如果使用参考策略,会执行此分支(因为ref_in_actor=False)
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
# ❌ 不会执行:ref_in_actor=False
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# ❌ 不会执行:GRPO不使用critic(use_critic=False)
# 脚本配置:algorithm.adv_estimator=grpo
# 根据代码372-384行,GRPO在列表中,所以use_critic=False
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with marked_timer("adv", timing_raw, color="brown"):
# ✅ 会执行:reward_extra_infos_dict类型注解
reward_extra_infos_dict: dict[str, list]
# ❌ 不会执行:launch_reward_fn_async=False
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
# ✅ 会执行:将token级别的分数添加到batch中
batch.batch["token_level_scores"] = reward_tensor
# ❓ 可能执行:如果reward_extra_infos_dict不为空
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# ❌ 不会执行:脚本配置:algorithm.use_kl_in_reward=False
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
# ✅ 会执行:直接将分数作为奖励(因为use_kl_in_reward=False)
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# ✅ 会执行:计算优势函数
# norm_adv_by_std_in_grpo: 默认为True(GRPO优势归一化因子)
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
)
# ✅ 会执行:使用GRPO计算优势(调用compute_grpo_outcome_advantage函数)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator, # GRPO
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n, # 5
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)
# ❌ 不会执行:GRPO不使用critic(use_critic=False)
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# ✅ 会执行:critic_warmup=0,所以条件总是满足(0 <= global_steps)
# 脚本配置:trainer.critic_warmup=0
if self.config.trainer.critic_warmup <= self.global_steps:
# ✅ 会执行:更新actor
with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
# 核心函数是(update_policy)
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# ❌ 不会执行:脚本中没有配置rollout_data_dir,默认为None
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
if "request_id" in batch.non_tensor_batch:
reward_extra_infos_dict.setdefault(
"request_id",
batch.non_tensor_batch["request_id"].tolist(),
)
self._dump_generations(
inputs=inputs,
outputs=outputs,
scores=scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
)
# ❓ 可能执行:取决于val_reward_fn是否配置,以及步数
# 脚本配置:trainer.test_freq=5,所以每5步会执行一次验证
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0 # 5 > 0,满足
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) # 每5步或最后一步
):
# ❓ 可能执行:如果满足条件(每5步或最后一步,且val_reward_fn不为None)
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
# ❓ 可能执行:如果是最后一步
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
# ✅ 会执行:检查ESI是否即将到期
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# ✅ 会执行:检查保存检查点的条件
# 脚本配置:trainer.save_freq=20,所以每20步会保存一次
if self.config.trainer.save_freq > 0 and ( # 20 > 0,满足
is_last_step
or self.global_steps % self.config.trainer.save_freq == 0 # 每20步
or esi_close_to_expiration
):
# ❓ 可能执行:如果ESI即将到期
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
# ✅ 会执行:如果满足保存条件(每20步或最后一步或ESI到期)
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()
# ✅ 会执行:停止性能分析
with marked_timer("stop_profile", timing_raw):
self._stop_profiling(do_profile)
# ✅ 会执行:获取当前步骤的持续时间
steps_duration = timing_raw["step"]
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
# ✅ 会执行:更新训练指标
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# ✅ 会执行:收集数据指标(use_critic=False)
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
# ✅ 会执行:收集计时指标
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# ✅ 会执行:获取GPU数量
n_gpus = self.resource_pool_manager.get_n_gpus()
# ✅ 会执行:收集吞吐量指标
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# ❓ 可能执行:取决于采样器类型
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
# ❓ 可能执行:如果使用课程学习采样器
self.train_dataloader.sampler.update(batch=batch)
# ✅ 会执行:记录指标到日志系统
logger.log(data=metrics, step=self.global_steps)
# ✅ 会执行:更新进度条
progress_bar.update(1)
# ✅ 会执行:增加全局步数
self.global_steps += 1
# ❓ 可能执行:如果是最后一步
if is_last_step:
# ❓ 可能执行:如果是最后一步
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
# ❓ 可能执行:取决于数据集是否有on_batch_end方法
if hasattr(self.train_dataset, "on_batch_end"):
# ❓ 可能执行:如果数据集有on_batch_end方法
self.train_dataset.on_batch_end(batch=batch)
优势计算
compute_grpo_outcome_advantage:
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape is (bs, response_length)
response_mask: `(torch.Tensor)`
shape is (bs, response_length)
index: `(np.ndarray)`
index array for grouping
epsilon: `(float)`
small value to avoid division by zero
norm_adv_by_std_in_grpo: `(bool)`
whether to scale the GRPO advantage
config: `(Optional[AlgoConfig])`
algorithm configuration object
Note:
If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO.
If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
Returns:
advantages: `(torch.Tensor)`
shape is (bs, response_length)
Returns: `(torch.Tensor)`
shape is (bs, response_length)
"""
# 步骤1: 计算每个样本的总奖励(对token维度求和)
# scores: torch.Tensor,形状从 (bs, response_length) 变为 (bs,)
# 将每个响应的所有token奖励求和,得到每个响应的总奖励分数
scores = token_level_rewards.sum(dim=-1)
# 步骤2: 初始化用于存储分组信息的字典
# id2score: defaultdict(list),键为index(uid),值为该组所有样本的分数列表
# id2mean: dict,键为index,值为该组分数的均值
# id2std: dict,键为index,值为该组分数的标准差
id2score = defaultdict(list)
id2mean = {}
id2std = {}
# 步骤3: 在no_grad上下文中计算(不需要梯度)
with torch.no_grad():
# 获取batch大小
# bsz: int类型,batch的大小
bsz = scores.shape[0]
# 步骤4: 按index(uid)分组收集分数
# 遍历batch中的每个样本,将分数按index分组
# index[i]: 第i个样本的uid(唯一标识符)
# scores[i]: 第i个样本的总奖励分数
# 相同uid的样本属于同一组(同一个prompt的多个响应)
for i in range(bsz):
id2score[index[i]].append(scores[i])
# 步骤5: 计算每个组的均值和标准差
# 遍历每个组(每个唯一的index/uid)
for idx in id2score:
# 如果组内只有一个样本(只有一个响应)
if len(id2score[idx]) == 1:
# 均值设为0(因为只有一个样本,无法计算相对优势)
# id2mean[idx]: torch.Tensor,标量,值为0.0
id2mean[idx] = torch.tensor(0.0)
# 标准差设为1(避免除零,同时保持数值稳定)
# id2std[idx]: torch.Tensor,标量,值为1.0
id2std[idx] = torch.tensor(1.0)
# 如果组内有多个样本(多个响应)
elif len(id2score[idx]) > 1:
# 计算组内分数的均值
# id2mean[idx]: torch.Tensor,标量,该组所有响应分数的平均值
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
# 计算组内分数的标准差
# id2std[idx]: torch.Tensor,标量,该组所有响应分数的标准差
# 注意:这里使用torch.tensor([id2score[idx]]),将列表包装成tensor再计算std
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
else:
# 异常情况:组内没有样本(不应该发生)
raise ValueError(f"no score in prompt index: {idx}")
# 步骤6: 计算归一化的优势分数
# 遍历batch中的每个样本
for i in range(bsz):
# 如果启用标准差归一化(原始GRPO方法)
if norm_adv_by_std_in_grpo:
# 计算标准化优势:(分数 - 组内均值) / (组内标准差 + epsilon)
# scores[i]: torch.Tensor,标量,更新后的优势分数
# id2mean[index[i]]: 该样本所在组的均值
# id2std[index[i]]: 该样本所在组的标准差
# epsilon: 防止除零的小值(1e-6)
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
else:
# 如果不归一化(Dr.GRPO方法):只减去均值
# scores[i]: torch.Tensor,标量,更新后的优势分数(相对于组内均值)
scores[i] = scores[i] - id2mean[index[i]]
# 步骤7: 将标量优势扩展到token级别
# scores.unsqueeze(-1): 从 (bs,) 扩展为 (bs, 1)
# response_mask: (bs, response_length),标识哪些位置是响应部分
# scores: torch.Tensor,形状从 (bs,) 变为 (bs, response_length)
# 每个token位置都使用相同的优势值(因为GRPO是基于outcome reward的)
scores = scores.unsqueeze(-1) * response_mask
# 步骤8: 返回优势和回报
# 在GRPO中,returns等于advantages(因为是基于outcome reward,没有时间折扣)
# scores: torch.Tensor,形状 (bs, response_length),既是advantages也是returns
return scores, scores
损失计算
update_policy:(loss=compute_policy_loss +0.001*KL loss)
def update_policy(self, data: DataProto):
# 步骤1: 设置为训练模式
self.actor_module.train()
# 步骤2: 提取temperature
temperature = data.meta_info["temperature"]
# 步骤3: 选择需要的键(对于GRPO)
select_keys = [
"responses", # 生成的响应,形状: (bs, response_length)
"response_mask", # 响应掩码,形状: (bs, response_length)
"input_ids", # 输入token IDs,形状: (bs, seq_len)
"attention_mask", # 注意力掩码,形状: (bs, seq_len)
"position_ids", # 位置IDs,形状: (bs, seq_len)
"old_log_probs", # 旧的对数概率,形状: (bs, response_length)
"advantages", # 优势(GRPO计算),形状: (bs, response_length)
]
# 脚本配置: actor_rollout_ref.actor.use_kl_loss=True
if self.config.use_kl_loss:
select_keys.append("ref_log_prob") # 参考策略的对数概率,形状: (bs, response_length)
# 步骤4: 选择数据
data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)
# 步骤5: 分割成mini-batches
# 脚本配置: actor_rollout_ref.actor.ppo_mini_batch_size=256
mini_batches = data.split(self.config.ppo_mini_batch_size)
metrics = {}
# 步骤6: 多个PPO epoch(通常为1-4个epoch)
for _ in range(self.config.ppo_epochs):
for batch_idx, mini_batch in enumerate(mini_batches):
# 步骤7: 进一步分割成micro-batches(用于梯度累积)
# 脚本配置: actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
else:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
) # 256 // 32 = 8
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
# 步骤8: 清零梯度
self.actor_optimizer.zero_grad()
# 步骤9: 遍历每个micro-batch
for micro_batch in micro_batches:
micro_batch_metrics = {}
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"] # (bs, response_length)
old_log_prob = model_inputs["old_log_probs"] # (bs, response_length)
advantages = model_inputs["advantages"] # (bs, response_length) - GRPO计算的
# 步骤10: 获取配置参数
clip_ratio = self.config.clip_ratio # PPO裁剪比例(通常0.2)
entropy_coeff = self.config.entropy_coeff # 熵系数(脚本中=0)
loss_agg_mode = self.config.loss_agg_mode # 损失聚合模式
# 步骤11: 前向传播,计算新的log_prob和entropy
# 脚本配置: actor_rollout_ref.actor.entropy_coeff=0,所以不计算entropy
calculate_entropy = False
if entropy_coeff != 0:
calculate_entropy = True
entropy, log_prob = self._forward_micro_batch(
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
)
# log_prob: (bs, response_length) - 当前策略的对数概率
# 步骤12: 计算策略损失(PPO clipped loss)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
old_log_prob=old_log_prob, # (bs, response_length)
log_prob=log_prob, # (bs, response_length)
advantages=advantages, # (bs, response_length) - GRPO优势
response_mask=response_mask, # (bs, response_length)
cliprange=clip_ratio, # 0.2
loss_agg_mode=loss_agg_mode,
)
# pg_loss: 标量,策略梯度损失
# pg_clipfrac: 被裁剪的比例
# ppo_kl: KL散度(用于监控)
# 步骤13: 计算总损失
# 脚本配置: entropy_coeff=0,所以不添加熵损失
if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
policy_loss = pg_loss - entropy_loss * entropy_coeff
else:
policy_loss = pg_loss
# 步骤14: 添加KL损失(如果启用)
# 脚本配置: actor_rollout_ref.actor.use_kl_loss=True, kl_loss_coef=0.001
if self.config.use_kl_loss:
ref_log_prob = model_inputs["ref_log_prob"] # (bs, response_length)
# 计算KL散度惩罚
kld = kl_penalty(
logprob=log_prob, # 当前策略
ref_logprob=ref_log_prob, # 参考策略
kl_penalty=self.config.kl_loss_type # "low_var_kl"
)
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef # 0.001
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
# 步骤15: 缩放损失(用于梯度累积)
if self.config.use_dynamic_bsz:
loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size)
else:
loss = policy_loss / self.gradient_accumulation # 除以梯度累积步数
# 步骤16: 反向传播
loss.backward()
# 步骤17: 记录指标
micro_batch_metrics.update({
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
})
append_to_dict(metrics, micro_batch_metrics)
# 步骤18: 优化器步骤(更新参数)
grad_norm = self._optimizer_step()
mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.actor_optimizer.zero_grad()
return metrics
compute_policy_loss 函数详解:重要性采样、Clipped PPO 和 Dual-clip PPO
重要性采样
- 起点:我们想要计算什么?
重要性采样要解决的核心问题是:我们想计算一个函数 \(f(x)\) 在目标分布 \(p(x)\) 下的期望值,但我们只有从另一个分布 \(q(x)\) 中采样的样本。
用数学语言表达,我们的目标是计算:
但我们无法直接从 \(p(x)\) 中采样,我们只有从 \(q(x)\) 中采样的样本 \(\{x_1, x_2, \ldots, x_n\} \sim q(x)\)。
2. 关键的数学“技巧”:乘以一个“1”
推导的巧妙之处在于一个简单的代数变换。我们在积分中乘以再除以 \(q(x)\),这相当于乘以1,不会改变积分的值:
为什么这么做?因为现在我们有了 \(q(x) dx\) 这一项,这正好是分布 \(q\) 的概率测度。
3. 重新解读积分:在新的分布下求期望
现在,观察上式的最右边:
这可以被看作是函数\(\frac{p(x)}{q(x)}f(x)\)在分布\(q(x)\)下的期望值:
这就是重要性采样的核心公式
4. 从理论到实践:蒙特卡洛估计
理论上,我们有上面的精确等式。但在实践中,我们无法计算真实的期望(因为它需要无穷多的样本),那么根据大数定律,我们从 \(q(x)\) 中采集 \(n\) 个样本 \(x_i\),然后用样本的平均值来近似期望:
其中,\(\frac{p(x_i)}{q(x_i)}\) 被称为 重要性权重(Importance Weight)。
- 直观理解重要性权重
重要性权重 \(\frac{p(x_i)}{q(x_i)}\) 起到了一个校正作用:
- 如果 \(p(x_i) > q(x_i)\):权重 > 1。这意味着在目标分布 \(p\) 中,\(x_i\) 出现的概率比在采样分布 \(q\) 中更高。因为我们幸运地从 \(q\) 中采到了这个在 \(p\) 中更常见的点,所以我们给它更大的权重,以补偿它在 \(q\) 中出现的低概率。
- 如果 \(p(x_i) < q(x_i)\):权重 < 1。这意味着 \(x_i\) 在 \(p\) 中其实不那么常见,但在 \(q\) 中却很常见。我们采到了很多这样的点,但它们对目标分布 \(p\) 的代表性不强,所以需要降低权重。
- 重要性采样的方差(和重要性权重成平方关系)
加入设重要性权重为 \(\frac{p(x_i)}{q(x_i)}=w(x)\)
现在,我们来计算这个重要性采样估计量的方差。根据方差公式 \(\text{Var}(Y) = \mathbb{E}[Y^2] - (\mathbb{E}[Y])^2\),我们可以计算估计量 \(w(x)f(x)\) 的方差:
为了看清核心关系,我们假设 \(f(x)\) 是一个常数(比如为1),来单独看权重 \(w(x)\) 带来的影响。此时估计量就是 \(w(x)\) 本身,其期望是1(\(\mathbb{E}_{x\sim q}[w(x)]=\sum_x q(x)\times \frac{p(x)}{q(x)} = \sum_x p(x)=1\))。
那么它的方差为:
看,方差就等于权重的二阶矩 \(\mathbb{E}[w^2]\) 减去1。这就是“方差与权重 \(w\) 二阶矩直接相关”的最简洁体现。
所以,如果\(p\)和\(q\)的分布相差太远,那么在某处,\(w(x)\)可能会非常大,另外一处,\(w(x)\)又会非常小(因为要保证期望为1),而大的\(w(x)\)会导致\(\mathbb{E}_{x\sim q}[w(x)^2]\)很大,进而会导致重要性采样的高方差.
大数定律
定理陈述(强大数定律 - 柯尔莫哥洛夫版)
设 \(X_{1}, X_{2}, X_{3}, \cdots\) 是独立同分布的随机变量序列,且具有有限的期望值 \(\mu = \mathbb{E}[X_{1}]\) 和有限的方差 \(\operatorname{Var}(X_{1}) = \sigma^{2} < \infty\)。
那么,样本均值几乎必然收敛于期望值 \(\mu\):
记作:\(\frac{1}{n} \sum_{i=1}^{n} X_{i} \xrightarrow{a.s.} \mu\)。
背景问题
在策略梯度中,我们需要计算:
使用旧策略 收集的数据来更新新策略 时,需要重要性采样。
重要性采样原理
重要性采样允许用旧策略的样本估计新策略的期望:
重要性权重为:
代码实现
# 步骤1: 计算重要性采样比率
negative_approx_kl = log_prob - old_log_prob # log(π_new/π_old)
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl) # ratio = π_new/π_old
数学表达式:
为什么需要重要性采样?
-
数据复用:用旧策略的数据更新新策略
-
样本效率:减少重新采样
-
在线学习:在策略更新过程中利用历史数据
问题:重要性采样可能不稳定
当 与 差异较大时, 可能很大,导致:
-
高方差
-
训练不稳定
-
策略更新过大
2. Clipped PPO(标准 PPO)
compute_policy_loss 函数详解:重要性采样、Clipped PPO 和 Dual-clip PPO
1. 重要性采样(Importance Sampling)
背景问题
在策略梯度中,我们需要计算:
使用旧策略 \(\pi_{\text{old}}\) 收集的数据来更新新策略 \(\pi_{\text{new}}\) 时,需要重要性采样。
重要性采样原理
重要性采样允许用旧策略的样本估计新策略的期望:
重要性权重为:
代码实现
# 步骤1: 计算重要性采样比率
negative_approx_kl = log_prob - old_log_prob # log(π_new/π_old)
negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
ratio = torch.exp(negative_approx_kl) # ratio = π_new/π_old
数学表达式:
为什么需要重要性采样?
- 数据复用:用旧策略的数据更新新策略
- 样本效率:减少重新采样
- 在线学习:在策略更新过程中利用历史数据
问题:重要性采样可能不稳定
当 \(\pi_{\text{new}}\) 与 \(\pi_{\text{old}}\) 差异较大时,\(w(a|s)\) 可能很大,导致:
- 高方差
- 训练不稳定
- 策略更新过大
2. Clipped PPO(标准 PPO)
核心思想
限制重要性权重,防止策略更新过大。
标准策略梯度损失
PPO Clipped 损失
其中 \(\text{clip}(\text{ratio}, 1-\epsilon, 1+\epsilon)\) 将 ratio 限制在 \([1-\epsilon, 1+\epsilon]\) 范围内。
为什么取 min?
- 当 \(A(s,a) > 0\)(好动作)时:
- 如果 \(\text{ratio} > 1+\epsilon\),使用 \((1+\epsilon) \cdot A\),限制过度增加
- 如果 \(\text{ratio} < 1+\epsilon\),使用 \(\text{ratio} \cdot A\),允许正常更新
- 当 \(A(s,a) < 0\)(坏动作)时:
- 如果 \(\text{ratio} < 1-\epsilon\),使用 \((1-\epsilon) \cdot A\),限制过度减少
- 如果 \(\text{ratio} > 1-\epsilon\),使用 \(\text{ratio} \cdot A\),允许正常更新
等价形式(最大化目标,取负号后最小化):
代码实现
# 步骤2: 计算标准PPO clipped loss
pg_losses1 = -advantages * ratio # 未裁剪的损失: -ratio * A
pg_losses2 = -advantages * torch.clamp(
ratio, 1 - cliprange_low, 1 + cliprange_high # clip(ratio, 1-ε, 1+ε)
) # 裁剪后的损失: -clip(ratio, 1-ε, 1+ε) * A
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
# max(-ratio * A, -clip(ratio, 1-ε, 1+ε) * A)
数学表达式:
Clipped PPO 的优势
- 限制更新幅度:\(\epsilon\)(如 0.2)控制单步更新范围
- 降低方差:避免极端权重
- 提高稳定性:防止策略崩溃
3. Dual-clip PPO
背景问题
当 \(A(s,a) < 0\) 且 \(\text{ratio}\) 很大时,标准 PPO 可能仍允许较大的负更新,导致策略过度偏离。
核心思想
在负优势时,额外限制损失的下界,防止策略过度更新。
Dual-clip 公式
当 \(A(s,a) < 0\) 时,添加下界限制:
其中 \(c > 1\) 是下界常数(通常 \(c = 3.0\))。
为什么需要 Dual-clip?
当 \(A(s,a) < 0\) 且 \(\text{ratio}\) 很大时:
- 标准 PPO:\(L = -\text{ratio} \cdot A\) 可能很大(负优势 × 大 ratio = 大负损失)
- Dual-clip:限制为 \(-c \cdot A\),防止过度惩罚
代码实现
# 步骤3: 计算dual-clip PPO损失
pg_losses3 = -advantages * clip_ratio_c # -c * A (下界)
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
# min(-c * A, L^CLIP)
# 步骤4: 根据advantages的符号选择
pg_losses = torch.where(
advantages < 0,
clip_pg_losses2, # 负优势:使用dual-clip
clip_pg_losses1 # 正优势:使用标准PPO
)
数学表达式:
Dual-clip 的优势
- 防止过度惩罚:在负优势时限制下界
- 提高稳定性:避免策略在坏动作上过度更新
- 适合复杂任务:在奖励稀疏或噪声大时更稳健
4. 完整损失函数总结
三阶段损失计算
最终损失(聚合后):
可视化理解
对于正优势(\(A > 0\)):
- 如果 \(\text{ratio} < 1+\epsilon\):使用 \(L_1 = -\text{ratio} \cdot A\)
- 如果 \(\text{ratio} > 1+\epsilon\):使用 \(L_2 = -(1+\epsilon) \cdot A\)(裁剪)
对于负优势(\(A < 0\)):
- 如果 \(\text{ratio} > 1-\epsilon\):使用 \(L_1 = -\text{ratio} \cdot A\)
- 如果 \(\text{ratio} < 1-\epsilon\):使用 \(L_2 = -(1-\epsilon) \cdot A\)(裁剪)
- 如果 \(L^{\text{CLIP}} < -c \cdot A\):进一步限制为 \(-c \cdot A\)(dual-clip)
5. 为什么这样设计?
重要性采样
- 允许用旧数据更新新策略
- 提高样本效率
Clipped PPO
- 限制更新幅度(\(\epsilon = 0.2\))
- 降低方差,提高稳定性
Dual-clip PPO
- 在负优势时额外限制下界(\(c = 3.0\))
- 防止策略在坏动作上过度偏离
6. 对于 GRPO 的默认值
根据脚本配置:
| 参数 | 值 | 作用 |
|---|---|---|
cliprange |
0.2 |
标准 PPO 裁剪范围 \(\epsilon\) |
cliprange_low |
0.2 |
下界裁剪范围 |
cliprange_high |
0.2 |
上界裁剪范围 |
clip_ratio_c |
3.0 |
Dual-clip 下界常数 \(c\) |
loss_agg_mode |
"token-mean" |
对所有有效 token 求平均 |
7. 完整代码流程
def compute_policy_loss(
old_log_prob, # log π_old(a|s)
log_prob, # log π_new(a|s)
advantages, # A(s,a)
response_mask, # 掩码
cliprange=0.2, # ε
clip_ratio_c=3.0, # c
loss_agg_mode="token-mean"
):
# 1. 重要性采样比率
ratio = exp(log_prob - old_log_prob) # π_new/π_old
# 2. 标准PPO clipped loss
L1 = -ratio * advantages
L2 = -clip(ratio, 1-ε, 1+ε) * advantages
L_CLIP = max(L1, L2)
# 3. Dual-clip(仅当advantages < 0时)
L3 = -c * advantages
L_DUAL = where(advantages < 0, min(L_CLIP, L3), L_CLIP)
# 4. 聚合为标量
loss = mean(L_DUAL, mask=response_mask)
return loss, clipfrac, kl, clipfrac_lower
总结
- 重要性采样:用旧策略数据更新新策略
- Clipped PPO:限制更新幅度(\(\epsilon = 0.2\))
- Dual-clip PPO:在负优势时额外限制下界(\(c = 3.0\))
三者结合,在保证样本效率的同时,提高训练稳定性,适用于 GRPO 等强化学习算法。
kl_penalty 函数完整讲解(K1,K2,K3)
1. 函数概述
kl_penalty 用于计算当前策略 π_actor 与参考策略 π_ref 之间的 KL 散度,作为惩罚项加入损失,防止策略更新过大。
2. KL 散度的标准定义
对于两个概率分布 π_ref 和 π_actor,KL 散度定义为:
展开后:
其中:
ref_logprob = log π_ref(a|s)logprob = log π_actor(a|s)
3. 三种 KL 散度估计方法(K1、K2、K3)
K1 估计器(一阶近似,高方差)
公式:
代码实现:
if kl_penalty in ("kl", "k1"):
return logprob - ref_logprob
# 注意:代码中返回的是 logprob - ref_logprob
# 但标准KL应该是 ref_logprob - logprob
# 这里可能是为了与损失函数中的符号一致(损失中会取负号)
特点:
- 无偏估计
- 方差较高,可能不稳定
- 计算简单
K2 估计器(二阶近似,低方差)
公式:
代码实现:
if kl_penalty in ("mse", "k2"):
return 0.5 * (logprob - ref_logprob).square()
# = 0.5 * (logprob - ref_logprob)²
特点:
- 有偏估计(近似)
- 方差较低,更稳定
- 使用平方项平滑估计
K3 估计器(低方差近似,推荐)
公式推导:
使用泰勒展开近似 KL 散度:
定义比率:
K3 估计器公式:
展开为:
代码实现:
if kl_penalty in ("low_var_kl", "k3"):
kl = ref_logprob - logprob # log(π_ref/π_actor)
kl = torch.clamp(kl, min=-20, max=20) # 数值稳定性裁剪
ratio = torch.exp(kl) # ratio = π_ref/π_actor
kld = (ratio - kl - 1) # K3公式
return torch.clamp(kld, min=-10, max=10) # 最终裁剪
数学表达式:
其中:
- \(\text{ref\_logprob} - \text{logprob} = \log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\)
- \(\exp(\text{ref\_logprob} - \text{logprob}) = \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\)
特点:
- 无偏估计(在策略接近时)
- 低方差,更稳定
- 适合 PPO 训练
- 参考:http://joschu.net/blog/kl-approx.html
4. 其他方法
"abs"(绝对值)
代码:
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
5. 完整代码实现
def kl_penalty(
logprob: torch.FloatTensor, # 形状: (bs, response_length)
ref_logprob: torch.FloatTensor, # 形状: (bs, response_length)
kl_penalty: str # "kl"/"k1", "mse"/"k2", "low_var_kl"/"k3", "abs"
) -> torch.FloatTensor: # 返回: (bs, response_length)
"""
计算KL散度惩罚项
Args:
logprob: 当前策略的对数概率 log π_actor(a|s)
ref_logprob: 参考策略的对数概率 log π_ref(a|s)
kl_penalty: KL散度计算类型
Returns:
KL散度估计值,形状与输入相同
"""
# K1估计器:一阶近似
if kl_penalty in ("kl", "k1"):
# KL ≈ log π_actor - log π_ref
# 注意:这里返回的是负的KL(因为损失函数中会取负号)
return logprob - ref_logprob
# 绝对值方法
if kl_penalty == "abs":
# KL ≈ |log π_actor - log π_ref|
return (logprob - ref_logprob).abs()
# K2估计器:二阶近似
if kl_penalty in ("mse", "k2"):
# KL ≈ 0.5 * (log π_actor - log π_ref)²
return 0.5 * (logprob - ref_logprob).square()
# K3估计器:低方差近似(推荐)
if kl_penalty in ("low_var_kl", "k3"):
# kl = log(π_ref/π_actor) = ref_logprob - logprob
kl = ref_logprob - logprob
# 数值稳定性裁剪:防止exp溢出
kl = torch.clamp(kl, min=-20, max=20)
# ratio = π_ref/π_actor = exp(log(π_ref/π_actor))
ratio = torch.exp(kl)
# K3公式:KL = ratio - log(ratio) - 1
# 其中 log(ratio) = kl = ref_logprob - logprob
kld = (ratio - kl - 1)
# 最终裁剪:防止数值不稳定
return torch.clamp(kld, min=-10, max=10)
# 完整KL散度(需要完整logits,未实现)
if kl_penalty == "full":
raise NotImplementedError
raise NotImplementedError(f"Unknown kl_penalty type: {kl_penalty}")
6. 在 PPO 训练中的使用
在 update_policy 中的使用流程:
# 1. 计算KL散度(每个token位置)
kld = kl_penalty(
logprob=log_prob, # 当前策略的对数概率
ref_logprob=ref_log_prob, # 参考策略的对数概率
kl_penalty="low_var_kl" # 使用K3估计器
) # 形状: (bs, response_length)
# 2. 聚合为标量损失
kl_loss = agg_loss(
loss_mat=kld, # KL散度矩阵
loss_mask=response_mask, # 响应掩码
loss_agg_mode="token-mean" # 对所有有效token求平均
) # 标量
# 3. 添加到总损失中
policy_loss = policy_loss + kl_loss * kl_loss_coef
# = kl_loss * 0.001
7. 对于 run_qwen3-8b.sh 脚本
脚本配置:
actor_rollout_ref.actor.kl_loss_type=low_var_kl # 使用K3估计器
actor_rollout_ref.actor.kl_loss_coef=0.001 # KL损失系数
因此使用 K3 估计器:
8. 三种估计器的比较
| 估计器 | 公式 | 无偏性 | 方差 | 稳定性 | 推荐度 |
|---|---|---|---|---|---|
| K1 | \(\log \pi_{\text{ref}} - \log \pi_{\text{actor}}\) | ✅ 无偏 | ❌ 高 | ❌ 不稳定 | ⭐⭐ |
| K2 | \(\frac{1}{2}(\log \pi_{\text{ref}} - \log \pi_{\text{actor}})^2\) | ❌ 有偏 | ✅ 低 | ✅ 稳定 | ⭐⭐⭐ |
| K3 | \(\exp(\Delta) - \Delta - 1\),其中 \(\Delta = \log \pi_{\text{ref}} - \log \pi_{\text{actor}}\) | ✅ 无偏 | ✅ 低 | ✅ 稳定 | ⭐⭐⭐⭐⭐ |
9. 数学推导(K3 的推导)
K3 基于以下近似:
对于小的 \(x = \log \frac{\pi_{\text{ref}}}{\pi_{\text{actor}}}\),使用泰勒展开:
KL 散度可以写成:
使用恒等式 \(x = e^x - 1 - (e^x - 1 - x)\),当 \(x\) 接近 0 时:
因此:
这就是 K3 估计器的公式。
总结
- K1:简单但方差高
- K2:稳定但有偏
- K3:无偏且低方差,适合 PPO
脚本使用 K3(low_var_kl),系数为 0.001,有助于稳定训练并限制策略更新幅度。

浙公网安备 33010602011771号