Large Models - Reinforcement Learning - FlappyBird PPO Implementation - 93

from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
import flappy_bird_gymnasium
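# assumed dependencies: pip install flappy-bird-gymnasium gymnasium numpy torch wandb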


@dataclass 
class Args:
    env_id: str = "FlappyBird-v0"
    total_timesteps: int = 1000000
    learning_rate: float = 2.5e-4
    num_envs: int = 4
    num_steps: int = 128
    num_minibatches: int = 4
    update_epochs: int = 4
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    batch_size: int = 0
    minibatch_size: int = 0
    num_iterations: int = 0

args = Args()

# Create vectorized environment
envs = gym.vector.SyncVectorEnv([lambda: gym.make(args.env_id) for _ in range(args.num_envs)])
print("observation space:", envs.single_observation_space)
print("action space:", envs.single_action_space)

# Orthogonal weight initialization with constant bias; std is overridden for the critic/actor output layers below.
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

# Actor-critic agent: two separate MLPs; the critic predicts a scalar state value V(s), the actor outputs action logits.
class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),  # small std keeps initial logits near zero, so the starting policy is close to uniform
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), self.critic(x)

agent = Agent(envs=envs)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# Rollout storage: (num_steps, num_envs) buffers for observations, actions, log-probs, rewards, done flags, and value estimates
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape)
logprobs = torch.zeros((args.num_steps, args.num_envs))
rewards = torch.zeros((args.num_steps, args.num_envs))
dones = torch.zeros((args.num_steps, args.num_envs))
values = torch.zeros((args.num_steps, args.num_envs))

# Initialize game
next_obs, _ = envs.reset(seed=1)
next_obs = torch.Tensor(next_obs)
next_done = torch.zeros(args.num_envs)

args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
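# With the defaults above: batch_size = 4 * 128 = 512, minibatch_size = 512 // 4 = 128,
# num_iterations = 1_000_000 // 512 = 1953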


# Initialize wandb
import wandb
wandb.login()  # expects WANDB_API_KEY in the environment or a prior "wandb login"; avoid hardcoding the key
wandb.init(
    project="flappy-bird-ppo",
    config={
        "env_id": args.env_id,
        "total_timesteps": args.total_timesteps,
        "num_envs": args.num_envs,
        "num_steps": args.num_steps,
        "learning_rate": args.learning_rate,
        "num_minibatches": args.num_minibatches,
        "update_epochs": args.update_epochs,
        "gamma": args.gamma,
        "gae_lambda": args.gae_lambda,
        "clip_coef": args.clip_coef,
        "vf_coef": args.vf_coef,
        "max_grad_norm": args.max_grad_norm
    }
)

def train():
    global next_obs, next_done
    for iteration in range(1, args.num_iterations + 1):
        # collect data
        for step in range(args.num_steps):
            obs[step] = next_obs
            dones[step] = next_done

            # action logic
            with torch.no_grad():
                action, logprob, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # execute game step
            next_obs, reward, terminations, truncations, infos = envs.step(action.numpy())
            next_done = np.logical_or(terminations, truncations)  # an env counts as done if its episode terminated or was truncated
            next_obs, next_done = torch.Tensor(next_obs), torch.Tensor(next_done)
            rewards[step] = torch.tensor(reward).view(-1)

        # compute advantages
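        # GAE: delta_t = r_t + gamma * (1 - done) * V(s_{t+1}) - V(s_t)
        # A_t = delta_t + gamma * lambda * (1 - done) * A_{t+1}, accumulated backwards over the rollout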
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # optimize policy and value networks
        b_inds = np.arange(args.batch_size)
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()
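                # ratio = pi_new(a|s) / pi_old(a|s), computed in log space for numerical stability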

                # Policy loss: PPO clipped surrogate, -min(ratio * A, clip(ratio, 1 - clip_coef, 1 + clip_coef) * A)
                mb_advantages = b_advantages[mb_inds]
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss: clipped value update, 0.5 * max((V_new - return)^2, (V_clipped - return)^2)
                newvalue = newvalue.view(-1)
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -args.clip_coef,
                    args.clip_coef)
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()

                loss = pg_loss + v_loss * args.vf_coef  # combined objective (no entropy bonus term is used here)

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

        if iteration % 10 == 0:
            print(f"Iteration {iteration}, Loss: {loss.item():.4f}")
            wandb.log({
                "loss": loss.item(),
                "policy_loss": pg_loss.item(),
                "value_loss": v_loss.item(),
                "iteration": iteration
            })

    torch.save(agent.state_dict(), "flappy_bird_agent.pth")
    envs.close()

def eval_model():
    agent.load_state_dict(torch.load("flappy_bird_agent.pth"))
    agent.eval()
    
    env = gym.make(args.env_id, render_mode='human')
    
    for i in range(5):
        obs, _ = env.reset()
        total_reward = 0
        while True:
            env.render()
            
            with torch.no_grad():
                obs_tensor = torch.from_numpy(obs).float().unsqueeze(0)
                action, _, _ = agent.get_action_and_value(obs_tensor)
                
            obs, reward, terminated, truncated, _ = env.step(action.item())
            total_reward += reward
            
            if terminated or truncated:
                print(f"Episode {i+1} finished with reward {total_reward}")
                break
                
    env.close()

if __name__ == "__main__":
    train()
    eval_model()
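
The evaluation above samples actions from the policy distribution at every step. As a minimal, hypothetical variant (not part of the original script), a greedy rollout can instead take the argmax of the actor logits:

def greedy_action(agent, obs):
    # Hypothetical helper: deterministic action = argmax over the actor's logits.
    with torch.no_grad():
        obs_tensor = torch.from_numpy(obs).float().unsqueeze(0)
        logits = agent.actor(obs_tensor)
        return int(torch.argmax(logits, dim=-1).item())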

Sample training and evaluation output:

Iteration 1950, Loss: 0.0331
Episode 1 finished with reward 77.19999999999978
Episode 2 finished with reward 315.0000000000033
Episode 3 finished with reward 231.79999999999086
Episode 4 finished with reward 87.69999999999905
Episode 5 finished with reward 202.69999999999249

