"""My SAC continuous demo"""
import argparse
import copy
import os
import random
import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Normal
def parse_args() -> argparse.Namespace:
"""Parse arguments."""
parser = argparse.ArgumentParser(description="Training")
parser.add_argument("--seed", type=int, help="Fix random seed", default=0)
parser.add_argument(
"--log_path", type=str, help="Model path", default="./training_log/"
)
parser.add_argument(
"--device", type=str, help="Run on which device", default="cuda"
)
    parser.add_argument(
        "--max_buffer_size", type=int, help="Max buffer size", default=int(1e7)
    )
    parser.add_argument(
        "--min_buffer_size",
        type=int,
        help="Warm-up transitions collected before learning starts",
        default=int(5e4),
    )
parser.add_argument("--hidden_width", type=int, help="Hidden width", default=256)
parser.add_argument("--gamma", type=float, help="gamma", default=0.99)
parser.add_argument("--tau", type=float, help="tau", default=0.005)
parser.add_argument(
"--learning_rate", type=float, help="Learning rate", default=1e-3
)
    parser.add_argument(
        "--max_train_steps", type=int, help="Max training steps", default=int(1e7)
    )
    parser.add_argument("--batch_size", type=int, help="Batch size", default=256)
    parser.add_argument(
        "--evaluate_frequency",
        type=int,
        help="Evaluate every this many environment steps",
        default=int(1e6),
    )
return parser.parse_args()
def set_seed(seed: int) -> None:
"""Set seed for reproducibility."""
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
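    # Note: this makes PyTorch raise for ops that lack a deterministic kernel.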
torch.use_deterministic_algorithms(True)
class ReplayBuffer:
"""Replay buffer for storing transitions."""
def __init__(self, state_dim: int, action_dim: int) -> None:
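        # Capacity comes from the module-level args parsed in __main__.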
self.max_size = int(args.max_buffer_size)
self.count = 0
self.size = 0
self.state = np.zeros((self.max_size, state_dim))
self.action = np.zeros((self.max_size, action_dim))
self.reward = np.zeros((self.max_size, 1))
self.next_state = np.zeros((self.max_size, state_dim))
self.done = np.zeros((self.max_size, 1))
def store(
self,
state: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
next_state: np.ndarray,
done: np.ndarray,
) -> None:
"""Store a transition in the replay buffer."""
self.state[self.count] = state
self.action[self.count] = action
self.reward[self.count] = reward
self.next_state[self.count] = next_state
self.done[self.count] = done
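        # Circular buffer: once full, the oldest transitions are overwritten.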
self.count = (self.count + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
def sample(self, batch_size: int) -> tuple:
"""Sample a batch of transitions."""
index = np.random.choice(self.size, size=batch_size)
batch_state = torch.tensor(self.state[index], dtype=torch.float).to(args.device)
batch_action = torch.tensor(self.action[index], dtype=torch.float).to(
args.device
)
batch_reward = torch.tensor(self.reward[index], dtype=torch.float).to(
args.device
)
batch_next_state = torch.tensor(self.next_state[index], dtype=torch.float).to(
args.device
)
batch_done = torch.tensor(self.done[index], dtype=torch.float).to(args.device)
return batch_state, batch_action, batch_reward, batch_next_state, batch_done
class Actor(nn.Module):
"""Actor network."""
def __init__(
self, state_dim: int, action_dim: int, hidden_width: int, max_action: float
) -> None:
super().__init__()
self.max_action = max_action
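        # MLP trunk: input layer, one residual block, and an output layer, each
        # with LayerNorm; separate heads produce the mean and log-std.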
self.in_layer = nn.Sequential(
nn.Linear(state_dim, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
)
self.res_layer = nn.Sequential(
nn.Linear(hidden_width, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
nn.Linear(hidden_width, hidden_width),
)
self.out_layer = nn.Sequential(
nn.Linear(hidden_width, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
)
self.mean_layer = nn.Sequential(nn.ReLU(), nn.Linear(hidden_width, action_dim))
self.log_std_layer = nn.Sequential(
nn.ReLU(inplace=True), nn.Linear(hidden_width, action_dim)
)
def forward(self, x: torch.Tensor, deterministic: bool = False) -> tuple:
"""Forward pass."""
x = self.in_layer(x)
x = self.out_layer(x + self.res_layer(x))
mean = self.mean_layer(x)
log_std = self.log_std_layer(x)
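        # Clamp log-std to a numerically stable range before exponentiating.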
log_std = torch.clamp(log_std, -20, 2)
std = torch.exp(log_std)
dist = Normal(mean, std)
if deterministic:
action = mean
else:
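            # Reparameterized sample so gradients flow through the action.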
action = dist.rsample()
log_pi = dist.log_prob(action).sum(dim=1, keepdim=True)
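        # Tanh-squashing correction to the log-probability, written in its
        # numerically stable softplus form.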
log_pi -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum(
dim=1, keepdim=True
)
action = self.max_action * torch.tanh(action)
return action, log_pi
class Critic(nn.Module):
"""Critic network."""
def __init__(self, state_dim: int, action_dim: int, hidden_width: int) -> None:
super().__init__()
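        # Two independent Q-networks (clipped double-Q) reduce value overestimation.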
self.in_layer1 = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
)
self.res_layer1 = nn.Sequential(
nn.Linear(hidden_width, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
nn.Linear(hidden_width, hidden_width),
)
self.out_layer1 = nn.Sequential(
nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
)
self.in_layer2 = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
)
self.res_layer2 = nn.Sequential(
nn.Linear(hidden_width, hidden_width),
nn.ReLU(inplace=True),
nn.LayerNorm(hidden_width),
nn.Linear(hidden_width, hidden_width),
)
self.out_layer2 = nn.Sequential(
nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
)
def forward(self, state: torch.Tensor, action: torch.Tensor) -> tuple:
"""Forward pass."""
state_action = torch.cat([state, action], 1)
q1 = self.in_layer1(state_action)
q1 = self.out_layer1(q1 + self.res_layer1(q1))
q2 = self.in_layer2(state_action)
q2 = self.out_layer2(q2 + self.res_layer2(q2))
return q1, q2
class SACContinuous:
"""Soft Actor-Critic for continuous action space."""
def __init__(self, state_dim: int, action_dim: int, max_action: float) -> None:
self.gamma = args.gamma
self.tau = args.tau
self.batch_size = args.batch_size
self.learning_rate = args.learning_rate
self.hidden_width = args.hidden_width
self.max_action = max_action
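        # Entropy target for automatic temperature tuning; -action_dim is the
        # common default, this demo uses -log(2 * action_dim) instead.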
self.target_entropy = -np.log(2 * action_dim)
        self.log_alpha = torch.tensor(1.0, requires_grad=True, device=args.device)
self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.learning_rate)
self.actor = Actor(state_dim, action_dim, self.hidden_width, max_action).to(
args.device
)
self.actor_optimizer = torch.optim.Adam(
self.actor.parameters(), lr=self.learning_rate
)
self.critic = Critic(state_dim, action_dim, self.hidden_width).to(args.device)
self.critic_target = copy.deepcopy(self.critic).to(args.device)
self.critic_optimizer = torch.optim.Adam(
self.critic.parameters(), lr=self.learning_rate
)
def choose_action(
self, state: np.ndarray, deterministic: bool = False
) -> np.ndarray:
"""Choose action."""
        # Add a batch dimension, run the actor, then strip the batch dimension off.
        state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0).to(
            args.device
        )
        action, _ = self.actor(state, deterministic)
        return action.detach().cpu().numpy().flatten()
    def learn(self, replay_buffer: ReplayBuffer) -> None:
        """Run one SAC update step on a sampled mini-batch."""
        batch_state, batch_action, batch_reward, batch_next_state, batch_done = (
            replay_buffer.sample(self.batch_size)
        )
        # Compute the soft TD target from the target critic with clipped double-Q.
        with torch.no_grad():
            batch_next_action, log_pi_ = self.actor(batch_next_state)
            target_q1, target_q2 = self.critic_target(
                batch_next_state, batch_next_action
            )
            target_q = batch_reward + self.gamma * (1 - batch_done) * (
                torch.min(target_q1, target_q2) - self.log_alpha.exp() * log_pi_
            )
        current_q1, current_q2 = self.critic(batch_state, batch_action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
            current_q2, target_q
        )
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
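        # Freeze the critic so the actor update does not also change critic weights.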
for params in self.critic.parameters():
params.requires_grad = False
action, log_pi = self.actor(batch_state)
q1, q2 = self.critic(batch_state, action)
q = torch.min(q1, q2)
actor_loss = (self.log_alpha.exp() * log_pi - q).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for params in self.critic.parameters():
params.requires_grad = True
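        # Temperature (alpha) update: push policy entropy toward the target.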
alpha_loss = -(
self.log_alpha.exp() * (log_pi + self.target_entropy).detach()
).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
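        # Polyak (soft) update of the target critic.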
for param, target_param in zip(
self.critic.parameters(), self.critic_target.parameters()
):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
)
def evaluate_policy(env: gym.Env, agent: SACContinuous) -> float:
    """Evaluate the policy for one episode with a deterministic actor."""
    state = env.reset()[0]
    done = False
    episode_reward = 0
    action_num = 0
    agent.actor.eval()
    while not done:
        action = agent.choose_action(state, deterministic=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        state = next_state
        action_num += 1
        if action_num >= 1e6:
            print("action_num too large.")
            break
        if episode_reward <= -1e6:
            print("episode_reward too small.")
            break
    agent.actor.train()
    return episode_reward
def training() -> None:
"""My demo training function."""
env_name = "Pendulum-v1"
env = gym.make(env_name)
env_evaluate = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
agent = SACContinuous(state_dim, action_dim, max_action)
replay_buffer = ReplayBuffer(state_dim, action_dim)
evaluate_num = 0
total_steps = 0
while total_steps < args.max_train_steps:
state = env.reset()[0]
episode_steps = 0
done = False
while not done:
episode_steps += 1
action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Bootstrap only through true terminations, not time-limit truncations.
            replay_buffer.store(state, action, reward, next_state, terminated)
state = next_state
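            # Start learning only after enough warm-up transitions are collected.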
if total_steps >= args.min_buffer_size:
agent.learn(replay_buffer)
            if (total_steps + 1) % args.evaluate_frequency == 0:
evaluate_num += 1
evaluate_reward = evaluate_policy(env_evaluate, agent)
print(
f"evaluate_num: {evaluate_num} \t evaluate_reward: {evaluate_reward}"
)
total_steps += 1
if total_steps >= args.max_train_steps:
break
    env.close()
    env_evaluate.close()
    # Create the log directory if needed before saving the trained actor.
    os.makedirs(args.log_path, exist_ok=True)
    torch.save(agent.actor.state_dict(), f"{args.log_path}/trained_model.pth")
def testing() -> None:
"""My demo testing function."""
    env_name = "Pendulum-v1"
    # render_mode="human" is required for on-screen rendering with the 5-tuple step API.
    env = gym.make(env_name, render_mode="human")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
agent = SACContinuous(state_dim, action_dim, max_action)
    agent.actor.load_state_dict(
        torch.load(f"{args.log_path}/trained_model.pth", map_location=args.device)
    )
agent.actor.eval()
state = env.reset()[0]
total_rewards = 0
with torch.no_grad():
        for _ in range(10000):
            action = agent.choose_action(state, deterministic=True)
            new_state, reward, terminated, truncated, _ = env.step(action)
            total_rewards += reward
            # Reset once the episode ends instead of stepping a finished env.
            state = env.reset()[0] if terminated or truncated else new_state
env.close()
print(f"SAC actor scores: {total_rewards}")
if __name__ == "__main__":
args = parse_args()
set_seed(args.seed)
training()
testing()