Deep Learning专栏--强化学习之Q-Learning与DQN(2)

 

在上一篇文章中介绍了MDP与Bellman方程,MDP可以对强化学习的问题进行建模,Bellman提供了计算价值函数的迭代公式。但在实际问题中,我们往往无法准确获知MDP过程中的转移概率$P$,因此无法直接将解决 MDP 问题的经典思路 value iteration 和 policy iteration 应用到解决强化学习的问题上。为了将转移概率以逼近实际情况的方式计算出来,基于value iteration的Q-Learning算法应运而生,它通过在迭代过程中不断更新Q-table的方式来近似转移概率矩阵$P$。此外,Sarsa还可以online的形式学习,区别在于与Q-Learning的迭代过程不同。最后,本文还将介绍DQN (Deep Q-Learning Network),当转移矩阵P (Q-table) 过大时计算困难甚至无法计算,而DQN其利用Deep网络结构拟合Q-table,使得Q-Learning框架具备了解决状态无限(动作仍旧有限)的强化学习问题

 

1. Q-Learning

Q-Learning 是一个强化学习中一个很经典的算法,其出发点很简单,就是用一张表存储在各个状态下执行各种动作能够带来的 reward,如下表表示了有两个状态 $s1$,$s2$,每个状态下有两个动作 $a1$,$a2$, 表格里面的值表示 reward

reward action 1 action 2
state 1 -1 2
state 2 -5 2

这个表就是 Q-Table,里面的每个值定义为$ Q(s,a) $, 表示在状态$ s $下执行动作$ a $ 所获取的reward,那么选择的时候可以采用一个贪婪的做法,即选择价值最大的那个动作去执行。

Q-Table通过随机初始化来生成初始表格,然后通过不断执行动作获取环境的反馈并通过算法更新 Q-Table。下面重点讲如何通过算法更新 Q-Table。

当我们处于某个状态$ s $时,根据 Q-Table 的值选择的动作$ a $, 那么从表格获取的 reward 为 $ Q(s,a) $,此时的 reward 并不是我们真正的获取的 reward,而是预期获取的 reward,那么真正的 reward 在哪?我们知道执行了动作$ a $并转移到了下一个状态$ s′ $时,能够获取一个即时的 reward(记为$ r $), 但是除了即时的 reward,还要考虑所转移到的状态 $ s′ $ 对未来期望的reward,因此真实的 reward (记为$ Q′(s,a) $)由两部分组成:即时的 reward 和未来期望的 reward,且未来的 reward 往往是不确定的,因此需要加个折扣因子$ \gamma $,则真实的 reward 表示如下:

$$ Q’(s,a) = r + \gamma\max_{a’}Q(s’,a’) $$

$ \gamma $的值一般设置为 0 到 1 之间,设为0时表示只关心即时回报,设为 1 时表示未来的期望回报跟即时回报一样重要。

有了真实的 reward 和预期获取的 reward,可以很自然地想到用 supervised learning那一套,求两者的误差然后进行更新,在 Q-learning 中也是这么干的,更新的值则是原来的$ Q(s, a) $,更新规则如下:

$$ Q(s, a) = Q(s, a) + \alpha(Q’(s, a) - Q(s,a)) $$

更新规则跟梯度下降非常相似,这里的$ \alpha $可理解为学习率。更新规则也很简单,可是这一类采用了贪心思想的算法往往都会有这么一个问题:算法是否能够收敛,是收敛到局部最优还是全局最优?

关于收敛性,可以参考 Convergence of Q-learning: a simple proof,这个文档 证明了这个算法能够收敛,且根据知乎上这个问题 RL两大类算法的本质区别?(Policy Gradient 和 Q-Learning),原始的 Q-Learning 理论上能够收敛到最优解,但是通过 Q 函数近似 Q-Table 的方法则未必能够收敛到最优解(如DQN)

除此之外, Q-Learning 中还存在着探索与利用(Exploration and Exploition)的问题, 大致的意思就是不要每次都遵循着当前看起来是最好的方案,而是会选择一些当前看起来不是最优的策略,这样也许会更快探索出更优的策略。Exploration and Exploition 的做法很多,Q-Learning 采用了最简单的$ \epsilon-greedy $, 就是每次有$ \epsilon $的概率是选择当前 Q-Table 里面值最大的action的,$ 1 - \epsilon $的概率是随机选择策略的。

Q-Learning 算法的流程如下,图片摘自这里

 

上面的流程中的 Q 现实 就是上面说的 $ Q′(s,a) $, Q 估计就是上面说的$ Q(s,a) $。

下面的 python 代码演示了更新通过 Q-Table 的算法, 参考了这个 repo 上的代码,初始化主要是设定一些参数,并建立 Q-Table, choose_action 是根据当前的状态 observation,并以 $ \epsilon-greedy $ 的策略选择当前的动作; learn 则是更新当前的 Q-Table,check_state_exist 则是检查当前的状态是否已经存在 Q-Table 中,若不存在要在 Q-Table 中创建相应的行。

 1 import numpy as np
 2 import pandas as pd
 3 
 4 class QTable:
 5     def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
 6         self.actions = actions  # a list
 7         self.lr = learning_rate
 8         self.gamma = reward_decay
 9         self.epsilon = e_greedy
10         self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
11 
12     def choose_action(self, observation):
13         self.check_state_exist(observation)
14         # action selection
15         if np.random.uniform() < self.epsilon:
16             # choose best action
17             state_action = self.q_table.ix[observation, :]
18             state_action = state_action.reindex(np.random.permutation(state_action.index))     # some actions have same value
19             action = state_action.argmax()
20         else:
21             # choose random action
22             action = np.random.choice(self.actions)
23         return action
24 
25     def learn(self, s, a, r, s_):
26         self.check_state_exist(s_)
27         q_predict = self.q_table.ix[s, a]
28         if s_ != 'terminal':
29             q_target = r + self.gamma * self.q_table.ix[s_, :].max()  # next state is not terminal
30         else:
31             q_target = r  # next state is terminal
32         self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
33 
34     def check_state_exist(self, state):
35         if state not in self.q_table.index:
36             # append new state to q table
37             self.q_table = self.q_table.append(
38                 pd.Series(
39                     [0]*len(self.actions),
40                     index=self.q_table.columns,
41                     name=state,
42                 )
43             )
View Code

 

2. Sarsa

Sarsa 跟 Q-Learning 非常相似,也是基于 Q-Table 进行决策的。不同点在于决定下一状态所执行的动作的策略,Q-Learning 在当前状态更新 Q-Table 时会用到下一状态Q值最大的那个动作,但是下一状态未必就会选择那个动作;但是 Sarsa 会在当前状态先决定下一状态要执行的动作,并且用下一状态要执行的动作的 Q 值来更新当前状态的 Q 值;说的好像很绕,但是看一下下面的流程便可知道这两者的具体差异了,图片摘自这里

那么,这两者的区别在哪里呢?这篇文章里面是这样讲的

This means that SARSA takes into account the control policy by which the agent is moving, and incorporates that into its update of action values, where Q-learning simply assumes that an optimal policy is being followed.

简单来说就是 Sarsa 在执行action时会考虑到全局(如更新当前的 Q 值时会先确定下一步要走的动作), 而 Q-Learning 则显得更加的贪婪和”短视”, 每次都会选择当前利益最大的动作(不考虑 $ \epsilon-greedy $),而不考虑其他状态。

那么该如何选择,根据这个问题:When to choose SARSA vs. Q Learning,有如下结论

If your goal is to train an optimal agent in simulation, or in a low-cost and fast-iterating environment, then Q-learning is a good choice, due to the first point (learning optimal policy directly). If your agent learns online, and you care about rewards gained whilst learning, then SARSA may be a better choice.

简单来说就是如果要在线学习,同时兼顾 reward 和总体的策略(如不能太激进,agent 不能很快挂掉),那么选择 Sarsa;而如果没有在线的需求的话,可以通过 Q-Learning 线下模拟找到最好的 agent。所以也称 Sarsa 为on-policy,Q-Learning 为 off-policy。

 

3. DQN

我们前面提到的两种方法都以依赖于 Q-Table,但是其中存在的一个问题就是当 Q-Table 中的状态比较多,可能会导致整个 Q-Table 无法装下内存。因此,DQN 被提了出来,DQN 全称是 Deep Q Network,Deep 指的是通的是深度学习,其实就是通过神经网络来拟合整张 Q-Table。

DQN 能够解决状态无限,动作有限的问题;具体来说就是将当前状态作为输入,输出的是各个动作的 Q 值。以 Flappy Bird 这个游戏为例,输入的状态近乎是无限的(当前 bird 的位置和周围的水管的分布位置等),但是输出的动作只有两个(飞或者不飞)。实际上,已经有人通过 DQN 来玩这个游戏了,具体可参考这个 DeepLearningFlappyBird

所以在 DQN 中的核心问题在于如何训练整个神经网络,其实训练算法跟 Q-Learning 的训练算法非常相似,需要利用 Q 估计和 Q 现实的差值,然后进行反向传播。

这里放上提出 DQN 的原始论文 Playing atari with deep reinforcement learning 中的算法流程图

 

上面的算法跟 Q-Learning 最大的不同就是多了 Experience Replay 这个部分,实际上这个机制做的事情就是先进行反复的实验,并将这些实验步骤获取的 sample 存储在 memory 中,每一步就是一个 sample,每个sample是一个四元组,包括:当前的状态,当前状态的各种action的 Q 值,当前采取的action获得的即时回报,下一个状态的各种action的Q值。拿到这样一个 sample 后,就可以根据上面提到的 Q-Learning 更新算法来更新网络,只是这时候需要进行的是反向传播。

Experience Replay 机制的出发点是按照时间顺序所构造的样本之间是有关的(如上面的$ \phi(s_{t+1}) $ 会受到$ \phi(s_{t}) $的影响)、非静态的(highly correlated and non-stationary),这样会很容易导致训练的结果难以收敛。通过 Experience Replay 机制对存储下来的样本进行随机采样,在一定程度上能够去除这种相关性,进而更容易收敛。当然,这种方法也有弊端,就是训练的时候是 offline 的形式,无法做到 online 的形式。

除此之外,上面算法流程图中的 aciton-value function 就是一个深度神经网络,因为神经网络是被证明有万有逼近的能力的,也就是能够拟合任意一个函数;一个 episode 相当于 一个 epoch;同时也采用了$ \epsilon-greedy $策略。代码实现可参考上面 FlappyBird 的 DQN 实现。

上面提到的 DQN 是最原始的的网络,后面Deepmind 对其进行了多种改进,比如说 Nature DQN 增加了一种新机制 separate Target Network,就是计算上图的$ y_j $ 的时候不采用网络 $ Q $, 而是采用另外一个网络(也就是 Target Network) $ Q′ $, 原因是上面计算$ y_j $和 Q 估计都采用相同的网络$ Q $,这样使得$ Q $大的样本,$ y $也会大,这样模型震荡和发散可能性变大,其原因其实还是两者的关联性较大。而采用另外一个独立的网络使得训练震荡发散可能性降低,更加稳定。一般$ Q′ $会直接采用旧的$ Q $, 比如说 10 个 epoch 前的$ Q $.

除此之外,大幅度提升 DQN 玩 Atari 性能的主要就是 Double DQN,Prioritised Replay 还有 Dueling Network 三大方法;这里不详细展开,有兴趣可参考这两篇文章:DQN从入门到放弃6 DQN的各种改进深度强化学习(Deep Reinforcement Learning)入门:RL base & DQN-DDPG-A3C introduction

综上,本文介绍了强化学习中基于 value 的方法:包括 Q-Learning 以及跟 Q-Learning 非常相似的 Sarsa,同时介绍了通过 DQN 解决状态无限导致 Q-Table过大的问题。需要注意的是 DQN 只能解决动作有限的问题,对于动作无限或者说动作取值为连续值的情况,需要依赖于 policy gradient 这一类算法,而这一类算法也是目前更为推崇的算法,在下一章将介绍 Policy Gradient 以及结合 Policy Gradient 和 Q-Learning 的 Actor-Critic 方法。

 

 

参考:

1. 强化学习笔记(2)-从 Q-Learning 到 DQN

2. DQN从入门到放弃5 深度解读DQN算法

posted @ 2019-03-29 15:00  蓝鲸王子  阅读(2859)  评论(0编辑  收藏  举报