DPO: Direct Preference Optimization 直接偏好优化(学习笔记)

学习参考:链接1  

一、为什么要提出DPO

  在之前,我们已经了解到基于人类反馈的强化学习RLHF分为三个阶段:全监督微调(SFT)、奖励模型(RM)、强化学习(PPO)。但是RLHF面临缺陷:RLHF 是一个复杂且经常不稳定的过程,首先拟合反映人类偏好的奖励模型,然后使用强化学习微调大型无监督 LM,以最大化这种估计奖励,而不会偏离原始模型太远。为解决这一问题,提出一个直接偏好优化 (DPO) 的新算法:通过利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化,本质上是在人类偏好数据上解决一个分类问题。DPO是稳定的、性能和计算成本轻量级的,无需拟合奖励模型,在微调期间从 LM 中采样,或执行显着的超参数调整。通过实验表明:DPO 进行微调超过了 RLHF 效果,并提高了摘要和单轮对话的响应质量。

二、什么是DPO

DPO,一种基于人类偏好优化语言模型的新方法。与RLHF不同,DPO不依赖于明确的奖励建模或强化学习。它针对与RLHF相同的目标,但提供了一种更简单、更直接的培训方法。

DPO的工作原理:增加偏好样本的对数概率与减小非偏好样本响应的对数概率。它结合了动态加权机制,以避免仅使用概率比目标时遇到的模型退化问题。

DPO依赖于理论上的偏好模型,如Bradley-Terry模型,来测量奖励函数与经验偏好数据的对齐程度。与传统方法不同,传统方法使用偏好模型来训练奖励模型,然后基于该奖励模型训练策略,DPO直接根据策略定义偏好损失。给定一个关于模型响应的人类偏好数据集,DPO可以使用简单的二元交叉熵目标来优化策略,无需在训练过程中明确学习奖励函数或从策略中采样具体推导见链接1  

(1)原RLHF的优化目标:最大化奖励和最小化参考策略的KL散度

(2)DPO优化目标:利用了从奖励函数到最优策略的解析映射,允许直接使用人类偏好数据进行简化的优化过程

该目标增加了对偏好数据$y_w$的可能性,并减少了非偏好数据$y_l$的可能性。这些示例按照隐式奖励模型的评级加权,由$\beta$缩放.

DPO重参数化等效于具有隐式奖励函数

参数模型$\pi_{\theta}$的优化等效于在此变量更改下的奖励模型优化。

(3)DPO在干什么?

为了从原理上理解 DPO,分析损失函数的梯度$L_{DPO} $。 相对于参数 θ 的梯度可以写为:

其中是由语言模型$\pi_{\theta}$和参考模型$\pi_{ref}$隐式定义的奖励函数。直观上,损失函数 $L_{DPO} $的梯度增加了偏好$y_w$ 的可能性,并降低了非偏好$y_l$的可能性。更重要的是,样例的权重是通过: 隐式奖励模型$\hat{r}_{\theta}$对非偏好的评分高多少来衡量的,即$\hat{r}_{\theta}(x,y_l)-\hat{r}_{\theta}(x,y_w)$,按 β 进行缩放,即隐式奖励模型认为策略模型错误的程度。 我们的实验表明了这种加权的重要性,因为没有加权系数的这种方法的简单版本可能会导致语言模型退化。

(4)DPO outline

步骤1)是在构造数据集,通过对同一问题的两种回复的倾向性:chosen or rejected,反映人类偏好。

步骤2)在于优化,具体过程大概是,对于同一个question prompt,模型在两种模型:language/policy model 和 reference model下分别生成,对应chosen 和 rejected label真值标签的生成概率,因此可以获得四种概率值policy_chosen_logpspolicy_rejected_logpsreference_chosen_logpsreference_rejected_logps, 用于DPO loss计算。

(5) Huagging Face DPO Trainer 

1、DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。

数据集由三部分组成:

  • prompt
  • chosen
  • rejected

可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset

 2、 预期模型格式

与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。

3、使用 DPOTrainer 源码
有关详细示例,请查看 Examples/scripts/dpo.py 脚本。 在较高级别上,我们需要使用我们希望训练的模型、参考 ref_model 来初始化 DPOTrainer,我们将使用它来计算首选和拒绝响应的隐式奖励,beta 指隐式奖励的超参数, 数据集包含上面列出的 3 个条目。 请注意,模型和 ref_model 需要具有相同的架构(即仅解码器或编码器-解码器)。

 dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

之后就可以调用:

dpo_trainer.train()

请注意,β 是 DPO 损失的温度参数,通常在 0.1 到 0.5 范围内。 当beta -> 0 ,意味着忽略参考模型。

4、损失函数

给定偏好数据,我们可以根据 Bradley-Terry 模型拟合二元分类器,事实上,DPO 作者通过 Logsigmoid 提出标准化似然的 sigmoid 损失来拟合逻辑回归。

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        if reference_free:
            ref_logratios = 0
        else:
            ref_logratios = reference_chosen_logps - reference_rejected_logps

        pi_logratios = pi_logratios.to(self.accelerator.device)
        ref_logratios = ref_logratios.to(self.accelerator.device)
        logits = pi_logratios - ref_logratios

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "ipo":
            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        elif self.loss_type == "kto_pair":
            # eqn (7) of the HALOs paper
            chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
            rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
            losses = torch.cat(
                (
                    1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
                    1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
                ),
                0,
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
            )

        chosen_rewards = (
            self.beta
            * (
                policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
            ).detach()
        )
        rejected_rewards = (
            self.beta
            * (
                policy_rejected_logps.to(self.accelerator.device)
                - reference_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        return losses, chosen_rewards, rejected_rewards

 其他改进的损失函数:

RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。

IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。

cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。

KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。

5、指标:在训练和评估时,记录以下奖励指标:

  • rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
  • rewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
  • rewards/accuracies: mean of how often the chosen rewards are > than the corresponding rejected rewards
  • rewards/margins: the mean difference between the chosen and corresponding rejected rewards
    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
        if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    with self.null_ref_context():
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
                            _,
                            _,
                        ) = self.concatenated_forward(self.model, batch)
                else:
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

 

posted @ 2024-01-15 09:56  kkzhang  阅读(3320)  评论(0编辑  收藏  举报