大模型- 强化学习-PPO算法的实现--92

参考

https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
https://newfacade.github.io/notes-on-reinforcement-learning/16-ppo-imp.html
https://gemini.google.com/app/247cc5d3d5bad7de

上一讲学习的 PPO 理论知识,特别是 PPO-Clip 的思想,完全转化为可以运行的代码。

三个核心部分组成:
RolloutBuffer: 一个缓冲区,用于存储智能体与环境交互时产生的数据。
ActorCritic: 包含策略网络(Actor)和价值网络(Critic)的模型。
PPO Agent: 封装了所有训练逻辑的主类,包括数据收集、优势计算和模型更新。

RolloutBuffer 类 - 数据存储仓库

PPO 是一种在线策略 (On-Policy) 算法,它需要先收集一批数据,然后用这批数据进行训练,训练完后就将这批数据丢弃,再收集新的一批。RolloutBuffer 就是为此设计的。
与DQN的 ReplayBuffer 的区别: DQN 的 ReplayBuffer 会存储很多旧的数据,并且在训练时随机采样,这是离线策略 (Off-Policy) 的特点。而 PPO 的 RolloutBuffer 只存储当前策略产生的一批数据,用完即弃。
这个类本质上就是一组列表,用来存储一个批次(比如2048个时间步)中每一步的状态、动作、动作的对数概率、奖励和终止状态信息。

class RolloutBuffer:
    def __init__(self):
        # 初始化空的列表来存储轨迹数据
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear(self):
        # 在每次更新后清空所有数据
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

ActorCritic 类 - PPO 的大脑

这个模型结构与我们之前在 Actor-Critic 中学到的非常相似,通常是一个共享底层网络、拥有两个不同输出头(一个给 Actor,一个给 Critic)的结构。

import torch
import torch.nn as nn
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()

        # Actor 网络 (策略网络)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

        # Critic 网络 (价值网络)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1) # 输出一个标量,代表状态价值 V(s)
        )

    def act(self, state):
        """
        用于数据收集阶段:根据单个状态决定动作。
        """
        # 计算动作概率
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        # 从分布中采样一个动作
        action = dist.sample()
        # 计算该动作的对数概率
        action_logprob = dist.log_prob(action)
        # Critic 对当前状态的价值评估
        state_val = self.critic(state)

        # 返回动作、其对数概率和状态价值,用于存入 Buffer
        return action.detach(), action_logprob.detach(), state_val.detach()

    def evaluate(self, state, action):
        """
        用于更新阶段:重新评估一批旧数据。
        """
        # 计算这批状态下的动作概率
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        # 计算这批旧动作在当前策略下的对数概率
        action_logprobs = dist.log_prob(action)
        # 计算当前策略的熵,用于鼓励探索
        dist_entropy = dist.entropy()
        # Critic 对这批状态的价值评估
        state_values = self.critic(state)

        # 返回评估结果,用于计算 PPO 损失
        return action_logprobs, state_values, dist_entropy

act vs evaluate:

  • act 用于“玩游戏”:输入一个当前状态,输出一个要执行的动作,以及这个动作的 logprob 和当前状态的 value,这些都会被存起来。--数据收集
  • evaluate 用于“复盘反思”:在更新阶段,我们拿出之前存好的一批 (state, action),用当前最新的网络去重新评估
    在这些 state 下,执行这些旧 action 的新 logprob 是多少?
    这些 state 的新 value 是多少?以及当前策略的熵是多少?这些是计算 PPO 损失所必需的。

3. PPO Agent 的 update 方法 - 核心训练逻辑

这是 PPO 算法的精髓所在,它实现了“多轮小批量优化”的流程。

class PPO:
    def __init__(self, ..., K_epochs, eps_clip, ...):
        # ... 初始化 ActorCritic 模型、优化器、超参数等 ...
        self.K_epochs = K_epochs          # 在同一批数据上优化的轮数 K
        self.eps_clip = eps_clip          # 裁剪参数 ε
        self.policy = ActorCritic(...)    # 当前策略
        self.policy_old = ActorCritic(...) # 旧策略,用于数据收集
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer = RolloutBuffer()

    def update(self):
        # --- 步骤 1: 使用 GAE 计算优势和回报 ---
        rewards = []
        discounted_reward = 0
        # 从后往前遍历收集到的数据
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            # G_t = r_t + γ * G_{t+1}
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        
        # 标准化回报
        rewards = torch.tensor(rewards, dtype=torch.float32)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # 将 buffer 中的数据转换为 PyTorch 张量
        old_states = torch.stack(self.buffer.states, dim=0).detach()
        old_actions = torch.stack(self.buffer.actions, dim=0).detach()
        old_logprobs = torch.stack(self.buffer.logprobs, dim=0).detach()

        # --- 步骤 2: 在同一批数据上进行 K 轮优化 ---
        for _ in range(self.K_epochs):
            # 调用 evaluate 函数,用当前策略 π_θ 重新评估旧数据
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # 计算 V(s) 和 V_target 的差距 (用于 Critic loss)
            # 注意 state_values 要调整维度匹配 rewards
            advantages = rewards - state_values.detach()

            # --- 步骤 3: 计算 PPO-Clip 损失 ---
            # 计算概率比率 r_t(θ) = exp(log π_θ(a|s) - log π_{θ_old}(a|s))
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # 计算 Surrogate Loss (L_CLIP)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            
            # 最终的损失是策略损失、价值损失和熵奖励的加权和
            # policy_loss: -min(surr1, surr2)
            # value_loss: MSE(V(s), G_t)
            # entropy_bonus: -entropy
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy
            
            # --- 步骤 4: 梯度下降 ---
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # --- 步骤 5: 更新旧策略的权重,为下一次数据收集做准备 ---
        self.policy_old.load_state_dict(self.policy.state_dict())

        # 清空 buffer
        self.buffer.clear()

ratios = torch.exp(logprobs - old_logprobs.detach()) 这一行计算了核心的概率比率
surr1 = ratios * advantages 这就是裁剪目标的第一部分
surr2 = torch.clamp(...) * advantages: 这就是裁剪目标的第二部分
loss = -torch.min(surr1, surr2) 完整地实现了L^CLIP
0.5 * self.MseLoss(state_values, rewards) 实现了价值损失 L^VF
0.01 * dist_entropy: 实现了熵奖励

这份代码是 PPO 算法一个非常典型且高效的实现。它清晰地展示了 PPO 的几个核心特点:
Actor-Critic 架构:同时学习策略和价值函数。
On-Policy 数据收集:使用一个 RolloutBuffer 来收集和管理数据,用完即弃。
GAE 计算优势:虽然在这份简化代码里直接用了 G_t 作为 V_target 并计算了优势,但实际 PPO 实现通常会集成 GAE 来获得更稳定的优势估计
Clipped Surrogate Objective通过 min 和 clamp 操作,稳健地限制了策略更新的幅度
多轮小批量优化: 在同一批数据上训练 K 轮,极大地提高了数据利用率。

完整代码

from dataclasses import dataclass

@dataclass
class Args:
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 200000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 4
    """the number of parallel game environments"""
    num_steps: int = 128
    """the number of steps to run in each environment per policy rollout"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

args = Args()


import gymnasium as gym

# Vectorized environment that serially runs multiple environments.
envs = gym.vector.make(args.env_id, num_envs=args.num_envs, asynchronous=False)
# envs


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Described in `Exact solutions to the nonlinear dynamics of learning in deep linear neural networks`
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), self.critic(x)

agent = Agent(envs=envs)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# ALGO Logic: Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape)
logprobs = torch.zeros((args.num_steps, args.num_envs))
rewards = torch.zeros((args.num_steps, args.num_envs))
dones = torch.zeros((args.num_steps, args.num_envs))
values = torch.zeros((args.num_steps, args.num_envs))

# start the game
next_obs, _ = envs.reset(seed=1)  # (num_envs, observation_space)
next_obs = torch.Tensor(next_obs)
next_done = torch.zeros(args.num_envs)

args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size    


def train():
    for iteration in range(1, args.num_iterations + 1):

        # collect data
        for step in range(args.num_steps):
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()  # (num_envs, 1)
            actions[step] = action  # (num_envs)
            logprobs[step] = logprob  # (num_envs)

            # execute the game
            next_obs, reward, terminations, truncations, infos = envs.step(action.numpy())
            next_done = np.logical_or(terminations, truncations)
            next_obs, next_done = torch.Tensor(next_obs), torch.Tensor(next_done)
            rewards[step] = torch.tensor(reward).view(-1)

        # compute advantages and rewards-to-go
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start: end]

                _, newlogprob, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    approx_kl = ((ratio - 1) - logratio).mean()

                mb_advantages = b_advantages[mb_inds]
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -args.clip_coef,
                    args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()

                loss = pg_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

                if epoch == args.update_epochs - 1 and iteration % 50 == 0:
                    print(iteration, loss.item())


        envs.close()

if __name__ == "__main__":
    train()

posted @ 2025-07-17 19:43  jack-chen666  阅读(165)  评论(0)    收藏  举报