PPO和GRPO算法详解(持续更新中)

PPO

众所周知,PPO在LLM应用下, t时刻下,State就变成了query+output(<t)

Reference是初始模型,举个例子可以认为是deepseek V3+SFT之后的模型,是不变的,是fozen model

从头开始推导一次:

t0时刻:


query输入到policy model里面,生成一个token,记为o1

同时,o1应该有一个对应的即时奖励rψ,通过reward model进行计算(人工设置)

而rt还要带一个通过当前policy和reference的分布计算一下KL散度,rt可以理解为这一步所获得的总即时奖励


Value Model对当前环境(Query+o1)进行评估

生成一个对当前状态的评估Vt(Value价值网络输入的是动作前状态,也就是Query)

以及对这个动作之后未来状态的评估Vt+1, 然后计算δt (Value价值网络输入的是动作后状态,也就是Query+o1)

r指的是即时奖励,V指的是上面的Vt和Vt+1,进行计算δ后计算优势函数A


我们注意到GAE指回去了Value Model,因为要用来计算Value网络的loss函数,更新Value网络

Vθ(St)是当前状态的估计奖励,rt+γVθ(St+1)是真实奖励,计算其平方差值,Vθ(St+1)和Vθ(St)就是上文提到的Vt+1和Vt


通过优势函数A和之前计算出来的KL散度,回去计算Policy Model的loss函数,完成更新

那么又一个问题出现了,Πθ和Πθold是怎么来的,当前只有一个policy网络,是怎么查找到两个概率分布的,在第一个时间t0里面,我们直接认为这两个数值是相等的,比值是1,计算完成后更新网络

要注意的是,更新之前buffer里面会存储一个值,就是当前policy model下对于状态(query+o1)下分布的存储,用于下一次计算,作为Πθold

也就是说policy model在这一轮里面使用了两次,第一次是动作预测(选择o1),第二次是当前状态下动作分布存储


t1时刻:

policy model生成下一个token,记为o2

后面的部分不再赘述,直接到更新网络的时刻

Πθ就是当前policy model下概率分布,取这个时刻o2的概率

Πθold就是上一次我们buffer存储的那个值,是上一个时刻policy model下对于状态(query+o1)下分布的存储,注意这个公式后面的状态是query+o1,没用o2

然后这一轮更新前也要存储一下分布,用于下一次计算


GRPO

posted @ 2025-02-26 14:34  Bronya_Silverwing  阅读(8)  评论(0)    收藏  举报