from dataclasses import dataclass
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
import flappy_bird_gymnasium  # importing registers FlappyBird-v0 with gymnasium
@dataclass
class Args:
env_id: str = "FlappyBird-v0"
total_timesteps: int = 1000000
learning_rate: float = 2.5e-4
num_envs: int = 4
num_steps: int = 128
num_minibatches: int = 4
update_epochs: int = 4
gamma: float = 0.99
gae_lambda: float = 0.95
clip_coef: float = 0.2
vf_coef: float = 0.5
max_grad_norm: float = 0.5
batch_size: int = 0  # filled in at runtime
minibatch_size: int = 0  # filled in at runtime
num_iterations: int = 0  # filled in at runtime
args = Args()
# Create vectorized environment
envs = gym.vector.SyncVectorEnv([lambda: gym.make(args.env_id) for _ in range(args.num_envs)])
print("observation space:", envs.single_observation_space)
print("action space:", envs.single_action_space)
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
return layer
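# Actor-critic agent: two separate two-layer MLPs. The critic outputs a scalar
# state value; the actor outputs one logit per discrete action.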
class Agent(nn.Module):
def __init__(self, envs):
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 1), std=1.0),
)
self.actor = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
nn.Tanh(),
layer_init(nn.Linear(64, 64)),
nn.Tanh(),
layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
)
def get_value(self, x):
return self.critic(x)
def get_action_and_value(self, x, action=None):
logits = self.actor(x)
probs = Categorical(logits=logits)
if action is None:
action = probs.sample()
return action, probs.log_prob(action), self.critic(x)
agent = Agent(envs=envs)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
# Storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape)
logprobs = torch.zeros((args.num_steps, args.num_envs))
rewards = torch.zeros((args.num_steps, args.num_envs))
dones = torch.zeros((args.num_steps, args.num_envs))
values = torch.zeros((args.num_steps, args.num_envs))
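# Rollout buffers: one slot per (step, env) pair; observations and actions keep their native shapes.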
# Initialize game
next_obs, _ = envs.reset(seed=1)
next_obs = torch.Tensor(next_obs)
next_done = torch.zeros(args.num_envs)
args.batch_size = int(args.num_envs * args.num_steps)  # transitions collected per iteration (4 * 128 = 512)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size  # 1_000_000 // 512 = 1953 updates
# Initialize wandb
import wandb
wandb.login()  # reads WANDB_API_KEY from the environment; never hard-code the key
wandb.init(
project="flappy-bird-ppo",
config={
"env_id": args.env_id,
"total_timesteps": args.total_timesteps,
"num_envs": args.num_envs,
"num_steps": args.num_steps,
"learning_rate": args.learning_rate,
"num_minibatches": args.num_minibatches,
"update_epochs": args.update_epochs,
"gamma": args.gamma,
"gae_lambda": args.gae_lambda,
"clip_coef": args.clip_coef,
"vf_coef": args.vf_coef,
"max_grad_norm": args.max_grad_norm
}
)
def train():
global next_obs, next_done
for iteration in range(1, args.num_iterations + 1):
# collect data
for step in range(args.num_steps):
obs[step] = next_obs
dones[step] = next_done
# action logic
with torch.no_grad():
action, logprob, value = agent.get_action_and_value(next_obs)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob
# execute game step
next_obs, reward, terminations, truncations, infos = envs.step(action.numpy())
next_done = np.logical_or(terminations, truncations)  # an episode ends on either termination or truncation
next_obs, next_done = torch.Tensor(next_obs), torch.Tensor(next_done)
rewards[step] = torch.tensor(reward).view(-1)
# compute advantages
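# Generalized Advantage Estimation (GAE), computed backwards through the rollout:
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
# The bootstrap value for the final step comes from the critic evaluated on next_obs.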
with torch.no_grad():
next_value = agent.get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
nextvalues = next_value
else:
nextnonterminal = 1.0 - dones[t + 1]
nextvalues = values[t + 1]
delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values
# flatten batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_logprobs = logprobs.reshape(-1)
b_advantages = advantages.reshape(-1)
b_returns = returns.reshape(-1)
b_values = values.reshape(-1)
# optimize policy and value networks
b_inds = np.arange(args.batch_size)
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.batch_size, args.minibatch_size):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]
_, newlogprob, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()
# Policy loss
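# Advantages are normalized per minibatch, then used in the PPO clipped surrogate:
# the objective is min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A),
# negated here so the optimizer can minimize it.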
mb_advantages = b_advantages[mb_inds]
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
# Value loss
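# Clipped value loss: penalize the larger of the unclipped and clipped squared
# errors against the empirical returns, mirroring the policy clipping.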
newvalue = newvalue.view(-1)
v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
v_clipped = b_values[mb_inds] + torch.clamp(
newvalue - b_values[mb_inds],
-args.clip_coef,
args.clip_coef)
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
loss = pg_loss + v_loss * args.vf_coef
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()
if iteration % 10 == 0:
print(f"Iteration {iteration}, Loss: {loss.item():.4f}")
wandb.log({
"loss": loss.item(),
"policy_loss": pg_loss.item(),
"value_loss": v_loss.item(),
"iteration": iteration
})
torch.save(agent.state_dict(), "flappy_bird_agent.pth")
envs.close()
def eval_model():
agent.load_state_dict(torch.load("flappy_bird_agent.pth"))
agent.eval()
env = gym.make(args.env_id, render_mode='human')
for i in range(5):
obs, _ = env.reset()
total_reward = 0
while True:
env.render()
with torch.no_grad():
obs_tensor = torch.from_numpy(obs).float().unsqueeze(0)
action, _, _ = agent.get_action_and_value(obs_tensor)  # samples from the policy; take the argmax of the logits for greedy play
obs, reward, terminated, truncated, _ = env.step(action.item())
total_reward += reward
if terminated or truncated:
print(f"Episode {i+1} finished with reward {total_reward}")
break
env.close()
if __name__ == "__main__":
train()
eval_model()
Iteration 1950, Loss: 0.0331
Episode 1 finished with reward 77.19999999999978
Episode 2 finished with reward 315.0000000000033
Episode 3 finished with reward 231.79999999999086
Episode 4 finished with reward 87.69999999999905
Episode 5 finished with reward 202.69999999999249