Large Models - Reinforcement Learning - Implementing the PPO Algorithm - 92
References
https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
https://newfacade.github.io/notes-on-reinforcement-learning/16-ppo-imp.html
https://gemini.google.com/app/247cc5d3d5bad7de
This lecture turns the PPO theory from the previous lecture, in particular the PPO-Clip idea, into code that actually runs.
The implementation consists of three core parts:
RolloutBuffer: a buffer that stores the data produced while the agent interacts with the environment.
ActorCritic: the model containing the policy network (Actor) and the value network (Critic).
PPO agent: the main class that encapsulates all of the training logic, including data collection, advantage estimation, and model updates.
1. The RolloutBuffer class - the data warehouse
PPO is an on-policy algorithm: it first collects a batch of data, trains on that batch, discards it once training is done, and then collects a fresh batch. The RolloutBuffer exists for exactly this purpose.
Difference from DQN's ReplayBuffer: DQN's ReplayBuffer keeps a large amount of old data and samples from it randomly during training, which is the hallmark of an off-policy method. PPO's RolloutBuffer only holds the batch produced by the current policy and throws it away once it has been used.
The class is essentially a set of lists that store, for every step of a batch (say, 2048 timesteps), the state, the action, the action's log-probability, the reward, and the terminal flag.
class RolloutBuffer:
    def __init__(self):
        # Initialize empty lists to store the trajectory data
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        # Clear all data after each update
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]
2. The ActorCritic class - the brain of PPO
The model structure is very similar to the Actor-Critic architecture we studied earlier. A common design shares a base network and attaches two output heads (one for the Actor, one for the Critic); the simplified implementation below instead uses two separate small MLPs.
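For reference, here is a minimal sketch of the shared-backbone variant mentioned above; the class name SharedActorCritic and the hidden_dim argument are illustrative and not part of the original code, which keeps the two networks fully separate.

import torch.nn as nn


class SharedActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Shared feature extractor used by both heads
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
        )
        # Actor head: action probabilities
        self.policy_head = nn.Sequential(nn.Linear(hidden_dim, action_dim), nn.Softmax(dim=-1))
        # Critic head: scalar state value V(s)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        features = self.backbone(state)
        return self.policy_head(features), self.value_head(features)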
import torch
import torch.nn as nn
from torch.distributions import Categorical
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        # Actor network (policy network)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )
        # Critic network (value network)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)  # outputs a single scalar: the state value V(s)
        )

    def act(self, state):
        """
        Used during data collection: choose an action for a single state.
        """
        # Compute the action probabilities
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        # Sample an action from the distribution
        action = dist.sample()
        # Log-probability of the sampled action
        action_logprob = dist.log_prob(action)
        # Critic's value estimate for the current state
        state_val = self.critic(state)
        # Return the action, its log-probability and the state value, to be stored in the buffer
        return action.detach(), action_logprob.detach(), state_val.detach()

    def evaluate(self, state, action):
        """
        Used during the update: re-evaluate a batch of old data.
        """
        # Action probabilities for this batch of states
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        # Log-probabilities of the old actions under the current policy
        action_logprobs = dist.log_prob(action)
        # Entropy of the current policy, used to encourage exploration
        dist_entropy = dist.entropy()
        # Critic's value estimates for this batch of states
        state_values = self.critic(state)
        # Return everything needed to compute the PPO loss
        return action_logprobs, state_values, dist_entropy
act vs. evaluate:
- act is for "playing the game": given the current state, it outputs the action to execute, together with that action's logprob and the state's value, all of which are stored. (Data collection; a collection-loop sketch follows this list.)
- evaluate is for "reviewing the game": during the update we take a previously stored batch of (state, action) pairs and re-evaluate it with the latest network:
  what is the new logprob of each old action in its state under the current policy?
  what is each state's new value, and what is the current policy's entropy? These are exactly the quantities needed to compute the PPO loss.
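The simplified PPO class in the next section only shows update(). For completeness, here is a minimal sketch of how act() and the RolloutBuffer are typically wired together during data collection; the environment choice (CartPole-v1), the 2048-step batch length, and the variable names are illustrative assumptions rather than part of the original code.

import gymnasium as gym
import torch

env = gym.make("CartPole-v1")
buffer = RolloutBuffer()
model = ActorCritic(state_dim=4, action_dim=2)  # CartPole: 4-dim state, 2 discrete actions

state, _ = env.reset()
for _ in range(2048):  # collect one batch of timesteps
    state_t = torch.FloatTensor(state)
    with torch.no_grad():
        action, logprob, value = model.act(state_t)  # value is returned, but this simplified buffer does not store it

    # store what evaluate()/update() will need later
    buffer.states.append(state_t)
    buffer.actions.append(action)
    buffer.logprobs.append(logprob)

    state, reward, terminated, truncated, _ = env.step(action.item())
    buffer.rewards.append(reward)
    buffer.is_terminals.append(terminated or truncated)
    if terminated or truncated:
        state, _ = env.reset()

Once a full batch has been collected, update() (shown next) consumes these lists and clear() empties them for the next rollout. In the PPO class below, this same loop would use self.policy_old to interact with the environment.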
3. The PPO agent's update method - the core training logic
This is the heart of the PPO algorithm: it implements the "multiple epochs of minibatch optimization" procedure.
class PPO:
    def __init__(self, ..., K_epochs, eps_clip, ...):
        # ... initialize the ActorCritic model, optimizer, hyperparameters, etc. ...
        self.K_epochs = K_epochs  # number of epochs K to optimize on the same batch
        self.eps_clip = eps_clip  # clipping parameter ε
        self.policy = ActorCritic(...)      # current policy
        self.policy_old = ActorCritic(...)  # old policy, used for data collection
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer = RolloutBuffer()

    def update(self):
        # --- Step 1: compute the returns (this simplified version uses the
        #     discounted rewards-to-go G_t rather than GAE) ---
        rewards = []
        discounted_reward = 0
        # Traverse the collected data backwards
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            # G_t = r_t + γ * G_{t+1}
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalize the returns
        rewards = torch.tensor(rewards, dtype=torch.float32)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Convert the buffer contents into PyTorch tensors
        old_states = torch.stack(self.buffer.states, dim=0).detach()
        old_actions = torch.stack(self.buffer.actions, dim=0).detach()
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).detach()

        # --- Step 2: optimize on the same batch for K epochs ---
        for _ in range(self.K_epochs):
            # Re-evaluate the old data with the current policy π_θ
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Squeeze state_values so its shape matches rewards (needed for the Critic loss)
            state_values = torch.squeeze(state_values)

            # Advantage estimate: A = G_t - V(s)
            advantages = rewards - state_values.detach()

            # --- Step 3: the PPO-Clip loss ---
            # Probability ratio r_t(θ) = exp(log π_θ(a|s) - log π_{θ_old}(a|s))
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Surrogate loss (L_CLIP)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            # The final loss is a weighted sum of the policy loss, the value loss and the entropy bonus
            # policy_loss:   -min(surr1, surr2)
            # value_loss:    MSE(V(s), G_t)
            # entropy_bonus: -entropy
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            # --- Step 4: gradient descent ---
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # --- Step 5: copy the new weights into the old policy for the next rollout ---
        self.policy_old.load_state_dict(self.policy.state_dict())

        # Clear the buffer
        self.buffer.clear()
ratios = torch.exp(logprobs - old_logprobs.detach()) computes the core probability ratio.
surr1 = ratios * advantages is the first term of the clipped objective.
surr2 = torch.clamp(...) * advantages is the second, clipped term.
loss = -torch.min(surr1, surr2) implements L^CLIP in full.
0.5 * self.MseLoss(state_values, rewards) implements the value loss L^VF.
0.01 * dist_entropy implements the entropy bonus (the full objective is written out below).
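In equation form, the pieces above assemble the standard PPO-Clip objective; the 0.5 and 0.01 weights are the value-loss and entropy coefficients used in this code:

$$
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}, \qquad
L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t\right)\right]
$$

$$
\text{loss} = \mathbb{E}_t\left[-L^{CLIP}_t(\theta) + 0.5\,\big(V_\theta(s_t) - G_t\big)^2 - 0.01\,H\big(\pi_\theta(\cdot \mid s_t)\big)\right]
$$

where, in this simplified version, $\hat{A}_t = G_t - V_\theta(s_t)$ and $H$ denotes the policy entropy.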
This code is a very typical and efficient implementation of the PPO algorithm. It clearly demonstrates PPO's core characteristics:
Actor-Critic architecture: the policy and the value function are learned simultaneously.
On-policy data collection: a RolloutBuffer collects and manages the data, which is discarded after use.
Advantage estimation with GAE: the simplified code above uses G_t directly as V_target and as the basis of the advantage, but practical PPO implementations usually integrate GAE for more stable advantage estimates (the recursion is written out after this list, and the full code below uses it).
Clipped surrogate objective: the min and clamp operations robustly limit how far each policy update can go.
Multi-epoch minibatch optimization: training for K epochs on the same batch greatly improves data utilization.
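For reference, the GAE recursion implemented in the full code below (γ is gamma and λ is gae_lambda in the Args dataclass; $d_{t+1}$ is the done flag recorded at step $t+1$, i.e. dones[t + 1] / next_done in the code):

$$
\delta_t = r_t + \gamma\, V(s_{t+1})\,(1 - d_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma\lambda\,(1 - d_{t+1})\,\hat{A}_{t+1}
$$

with the value-function targets computed as $R_t = \hat{A}_t + V(s_t)$ (the line returns = advantages + values).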
Full code
from dataclasses import dataclass
@dataclass
class Args:
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 200000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 4
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""
args = Args()
import gymnasium as gym
# Vectorized environment that serially runs multiple environments.
envs = gym.vector.make(args.env_id, num_envs=args.num_envs, asynchronous=False)
# envs
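# Note: gym.vector.make is deprecated in newer Gymnasium releases, which provide
# gym.make_vec(args.env_id, num_envs=args.num_envs) for the same purpose; use
# whichever matches the installed version.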
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Described in `Exact solutions to the nonlinear dynamics of learning in deep linear neural networks`
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer
class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), self.critic(x)
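# Note: the actor's output layer is initialized with std=0.01 so that the initial
# policy is close to uniform over actions, while the critic's output layer uses
# std=1.0; this is one of the implementation details discussed in the post linked
# in the references.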
agent = Agent(envs=envs)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape)
logprobs = torch.zeros((args.num_steps, args.num_envs))
rewards = torch.zeros((args.num_steps, args.num_envs))
dones = torch.zeros((args.num_steps, args.num_envs))
values = torch.zeros((args.num_steps, args.num_envs))
# start the game
next_obs, _ = envs.reset(seed=1) # (num_envs, observation_space)
next_obs = torch.Tensor(next_obs)
next_done = torch.zeros(args.num_envs)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
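# With the defaults above: batch_size = 4 * 128 = 512,
# minibatch_size = 512 // 4 = 128, and num_iterations = 200000 // 512 = 390.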
def train():
    # next_obs and next_done are defined at module level and re-assigned inside
    # this function, so they must be declared global
    global next_obs, next_done

    for iteration in range(1, args.num_iterations + 1):
        # collect data
        for step in range(args.num_steps):
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()  # (num_envs, 1) -> (num_envs,)
            actions[step] = action    # (num_envs,)
            logprobs[step] = logprob  # (num_envs,)

            # execute the game
            next_obs, reward, terminations, truncations, infos = envs.step(action.numpy())
            next_done = np.logical_or(terminations, truncations)
            next_obs, next_done = torch.Tensor(next_obs), torch.Tensor(next_done)
            rewards[step] = torch.tensor(reward).view(-1)

        # compute advantages and rewards-to-go
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start: end]

                _, newlogprob, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    approx_kl = ((ratio - 1) - logratio).mean()

                mb_advantages = b_advantages[mb_inds]
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -args.clip_coef,
                    args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()

                loss = pg_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if epoch == args.update_epochs - 1 and iteration % 50 == 0:
                print(iteration, loss.item())

    envs.close()


if __name__ == "__main__":
    train()
