八、Actor-Critic 算法

1 简介

我们之前介绍了基于价值函数的 DQN 算法和基于策略函数的 REINFORCE 算法,接下来将两者结合,既学习价值函数又学习策略函数的 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

2 Actor-Critic算法

REINFORCE 算法使用 Monte Carlo 方法来估计 q(s,a) 来指导策略的更新,而 Actor-Critic 算法则考虑拟合一个价值函数来指导策略的学习。在策略梯度中,我们将梯度写作下面这个更加一般的形式:

\[g=E[\sum_{t=0}^Tψ_t∇_θlogπ_θ(a_t|s_t)] \]

其中,\(ψ_t\)可以有很多形式:

\[1.\sum_{t'=0}^Tγ^{t'}r_{t'}:轨迹的总回报; \]

\[2.\sum_{t'=t}^Tγ^{t'-t}r_{t'}:动作 a_t 之后的回报; \]

\[3.\sum_{t'=t}^Tγ^{t'-t}r_{t'}-b(s_t):基准线版本的改进; \]

\[4.Q^{π_Θ}(s_t,a_t):动作价值函数; \]

\[5.A^{π_Θ}(s_t,a_t):优势函数; \]

\[6.r_{t+1}+γq(s_{t+1},a_{t+1},ω_t)-q(s_t,a_t,ω_t):时序差分残差。 \]

本章将着重介绍形式(6),即通过时序差分残差来指导策略梯度进行学习。事实上,用 q 值本质上也是用奖励来进行指导,但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。
我们将 Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络),如图所示。

  • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。
    image
    Actor 的更新采用策略梯度的原则,那 Critic 如何更新呢?我们将 Critic 价值网络表示为 \(V_ω\),参数为 \(ω\)。于是,我们可以采取时序差分残差的学习方式,对于单个数据定义如下价值函数的损失函数:

\[\mathcal{L}(ω)=\frac{1}{2}(r+γV_ω(s_{t+1})-V_ω(s_t))^2 \]

与 DQN 中一样,我们采取类似于目标网络的方法,将上式中作为时序差分目标,不会产生梯度来更新价值函数。因此,价值函数的梯度为:

\[∇_ω\mathcal{L}(ω)=-(r+γV_ω(s_{t+1})-V_ω(s_t))∇_ωV_ω(s_t) \]

然后使用梯度下降方法来更新 Critic 价值网络参数即可。

3 算法

rl_utils.py:

from tqdm import tqdm
import numpy as np
import torch
import collections
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) 

    def add(self, state, action, reward, next_state, done): 
        self.buffer.append((state, action, reward, next_state, done)) 

    def sample(self, batch_size): 
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done 

    def size(self): 
        return len(self.buffer)

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, terminated, truncated, _ = env.step(action)
                    if terminated or truncated:
                        done = True
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)
import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # 价值网络
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  
        # 策略网络优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        # 价值网络优化器
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(np.array(transition_dict['states']),
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(np.array(transition_dict['rewards']),
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']),
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(np.array(transition_dict['dones']),
                             dtype=torch.float).view(-1, 1).to(self.device)

        # 时序差分目标
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        td_delta = td_target - self.critic(states)  # 时序差分误差
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # 均方误差损失函数
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # 计算策略网络的梯度
        critic_loss.backward()  # 计算价值网络的梯度
        self.actor_optimizer.step()  # 更新策略网络的参数
        self.critic_optimizer.step()  # 更新价值网络的参数

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v1'
env = gym.make(env_name, new_step_api=True)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()
Iteration 0: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 153.71it/s, episode=100, return=20.700]
Iteration 1: 100%|███████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 92.93it/s, episode=200, return=44.500]
Iteration 2: 100%|███████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 52.09it/s, episode=300, return=79.100]
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.97it/s, episode=400, return=202.300] 
Iteration 4: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.64it/s, episode=500, return=354.100] 
Iteration 5: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 12.13it/s, episode=600, return=385.800] 
Iteration 6: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.20it/s, episode=700, return=435.000] 
Iteration 7: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:11<00:00,  8.70it/s, episode=800, return=490.500] 
Iteration 8: 100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  7.84it/s, episode=900, return=500.000] 
Iteration 9: 100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00,  7.87it/s, episode=1000, return=500.000]

在 CartPole-v1 环境中,满分就是 500 分。和 REINFORCE 相似,接下来我们绘制训练过程中每一条轨迹的回报变化图以及其经过平滑处理的版本。
image
image
根据实验结果我们可以发现,Actor-Critic 算法很快便能收敛到最优策略,并且训练过程非常稳定,抖动情况相比 REINFORCE 算法有了明显的改进,这说明价值函数的引入减小了方差。

参考资料

https://hrl.boyuai.com/chapter/2/actor-critic算法
https://www.bilibili.com/video/BV1sd4y167NS?p=52&vd_source=f7563459deb4ecb3add61713c7d5d111

posted @ 2024-05-05 11:03  Hell0er  阅读(355)  评论(0)    收藏  举报