Reinforcement Learning (6) - Sarsa
一、Concepts
1、Reinforcement learning has two problems: prediction and control.
Temporal-difference (TD) learning can solve the prediction problem.
Sarsa is a TD on-policy control algorithm; it solves the control problem.
2、Control algorithms come in two kinds: on-policy and off-policy.
On-policy control uses a single policy (e.g. ε-greedy) both to update the value function and to select actions.
Off-policy control uses two policies: one to update the value function and another to select actions.
3、The name Sarsa spells out what one update uses: S (state), A (action), R (reward), S' (next state), A' (next action).
4、The Q-learning algorithm
estimates the value of the next step using the action with the largest Q-value in the next state, but that action is not necessarily the one it will actually execute.
Sarsa also estimates the next action, but then commits to it and really takes it; it is the on-policy counterpart of Q-learning (see the sketch after this list).
5、Advantages: it needs no model of the environment and no complete state sequences, so it is widely used in classical (tabular) reinforcement learning.
6、Disadvantage: because it is limited by the size of the Q-table, it cannot solve very complex problems.
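To make the difference in item 4 concrete, here is a minimal sketch comparing how the two algorithms build their update target; the toy q_table, the values of s_ and a_, and the numbers are made up for illustration and are not part of the original code:

import pandas as pd

# a toy Q-table: two states, two actions, arbitrary values
q_table = pd.DataFrame([[0.1, 0.5],
                        [0.3, 0.2]], index=['s1', 's2'], columns=[0, 1])
gamma, r = 0.9, 1.0
s_ = 's2'   # next state
a_ = 1      # action actually chosen in s2 by the ε-greedy policy

# Q-learning target: the best action in s_, whatever will actually be executed
q_target_qlearning = r + gamma * q_table.loc[s_, :].max()   # 1 + 0.9*0.3 = 1.27

# Sarsa target: the action a_ that will really be taken in s_
q_target_sarsa = r + gamma * q_table.loc[s_, a_]            # 1 + 0.9*0.2 = 1.18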
二、Sarsa Update Formula
1、Estimate of Q(s, a): the current table entry Q(s, a).
2、Target ("reality") of Q(s, a): r + γ*Q(s', a'), where a' is the action actually chosen in s'.
3、Error = target - estimate = r + γ*Q(s', a') - Q(s, a).
4、Update: Q(s, a) += α*error, i.e. Q(s, a) = Q(s, a) + α*(r + γ*Q(s', a') - Q(s, a)), where α is the step size (learning rate).
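A quick numeric check of the update, with made-up values (α = 0.1, γ = 0.9; none of these numbers come from the original text):

alpha, gamma = 0.1, 0.9      # step size and reward-decay factor (illustrative)
q_sa, q_s_a_ = 0.5, 0.8      # current Q(s, a) and Q(s', a') read from the Q-table
r = 1.0                      # reward observed after taking a in s

error = r + gamma * q_s_a_ - q_sa    # target - estimate = 1 + 0.72 - 0.5 = 1.22
q_sa += alpha * error                # new Q(s, a) ≈ 0.5 + 0.1*1.22 = 0.622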

三、The Sarsa Algorithm
1、Input: number of training episodes T, state set S, action set A, step size α, decay factor γ, exploration rate ε.
2、Output: the Q-values of all state-action pairs. A sketch of the full loop follows.
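A sketch of the whole procedure, assuming a hypothetical environment that exposes reset() -> state, step(action) -> (next_state, reward, done) and an n_actions attribute; this interface is an assumption for illustration, not the Maze environment used in section 五:

import numpy as np
from collections import defaultdict

def sarsa(env, T=500, alpha=0.1, gamma=0.9, epsilon=0.1):
    # Q-values of every state-action pair, default 0
    Q = defaultdict(lambda: np.zeros(env.n_actions))

    def epsilon_greedy(s):
        # here ε is the probability of exploring; the code in section 五
        # uses the opposite convention (e_greedy is the exploitation probability)
        if np.random.rand() < epsilon:
            return np.random.randint(env.n_actions)
        return int(np.argmax(Q[s]))

    for _ in range(T):
        s = env.reset()
        a = epsilon_greedy(s)            # choose the first action
        done = False
        while not done:
            s_, r, done = env.step(a)    # take a, observe r and s'
            a_ = epsilon_greedy(s_)      # choose a' before learning: on-policy
            # TD update with the action a' that will actually be executed;
            # the bootstrap term is dropped on the terminal step
            Q[s][a] += alpha * (r + gamma * Q[s_][a_] * (not done) - Q[s][a])
            s, a = s_, a_
    return Q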
四、Sarsa(λ)
1、Single-step update: as soon as one step is taken and its feedback is received, that step is updated.
2、Episode update: only after a whole episode finishes are all the steps taken in it updated, and every step is updated with the same strength.
3、Each of the two has its own pros and cons, which is why λ was introduced: it sits between single-step and episode updates, and the closer a step is to the reward, the more strongly it is updated.
4、λ, with 0 < λ < 1, is the decay applied to the trail of visited steps, while the γ in the update formula is the decay applied to rewards.
5、Update formulas (E below is the eligibility_trace table from item 6):
error = r + γ*Q(s', a') - Q(s, a)
E(s, a) += 1 (accumulating trace), or E(s, ·) = 0 and E(s, a) = 1 (replacing trace)
for every state-action pair: Q(s, a) += α*error*E(s, a), then E(s, a) = γ*λ*E(s, a)
6、Compared with plain Sarsa, the implementation adds an eligibility_trace table, initialized as a copy of the Q-table, which marks the visited steps and applies the trail decay (a small demonstration of how it decays follows).
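A minimal sketch of how the replacing trace ("Method 2" in the code below) marks and then decays the visited steps; the state names, actions and decay values are made up for illustration:

import pandas as pd

actions = [0, 1]
states = ['s1', 's2', 's3']
trace = pd.DataFrame(0.0, index=states, columns=actions)   # eligibility trace, all zeros
gamma, lambda_ = 0.9, 0.8                                  # reward decay and trace decay
visited = [('s1', 0), ('s2', 1), ('s3', 0)]                # state-action pairs, in order

for s, a in visited:
    # replacing trace: the step just taken gets weight 1, other actions in s are cleared
    trace.loc[s, :] = 0.0
    trace.loc[s, a] = 1.0
    # ... here the whole Q-table would be updated with lr * error * trace ...
    # then every entry decays, so older steps keep smaller and smaller weights
    trace *= gamma * lambda_

print(trace)   # the s1 entry has decayed the most, the s3 entry the least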
五、Code
1、Algorithm update (the training loop)
from RL_brain import QLearningTable
from maze_env import Maze
from sarsa.RL_brain import SarsaTable, SarsaLambdaTable


def update():
    for episode in range(100):
        # reset the environment and get the initial observation
        observation = env.reset()
        # choose the first action ε-greedily from the initial observation
        action = RL.choose_action(str(observation))
        while True:
            env.render()
            # take the action, observe the next state, the reward and the done flag
            observation_, reward, done = env.step(action)
            # Sarsa: choose the next action from the *next* observation before learning
            action_ = RL.choose_action(str(observation_))
            # learn from the full (s, a, r, s', a') transition
            RL.learn(str(observation), action, reward, str(observation_), action_)
            # the next state and action become the current ones
            observation = observation_
            action = action_
            if done:
                break
    print('game over')
    env.destroy()


if __name__ == '__main__':
    env = Maze()
    # QLearningTable or SarsaTable can be swapped in here instead
    RL = SarsaLambdaTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()
2、Decision making (the RL_brain classes)
import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions)

    def choose_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # exploit: pick the action with the largest Q-value,
            # shuffling first so that ties are broken randomly
            state_action = self.q_table.loc[observation, :]
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # explore: pick a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass

    def check_state_exist(self, state):
        # add an all-zero row the first time a state is seen
        if state not in self.q_table.index:
            self.q_table = pd.concat(
                [self.q_table,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )


class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # off-policy: the target uses the maximum Q-value in the next state
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)


class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # on-policy: the target uses the action a_ that will actually be taken
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)


class SarsaLambdaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.1):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
        self.lambda_ = trace_decay
        # the eligibility trace has the same shape as the Q-table and marks visited steps
        self.eligibility_trace = self.q_table.copy()

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # on-policy target, same as SarsaTable
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        error = q_target - q_predict

        # Method 1: accumulating trace (kept for reference; Method 2 below overrides it)
        # self.eligibility_trace.loc[s, a] += 1

        # Method 2: replacing trace, the step just taken gets weight 1
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # update the whole Q-table, weighted by the eligibility trace
        self.q_table += self.lr * error * self.eligibility_trace
        # decay the trace, so older steps receive smaller and smaller updates
        self.eligibility_trace *= self.gamma * self.lambda_

    # both tables need a row added when a new state appears
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            self.q_table = pd.concat(
                [self.q_table,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )
            self.eligibility_trace = pd.concat(
                [self.eligibility_trace,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )
References:
https://www.bilibili.com/video/BV13W411Y75P/?share_source=copy_web&vd_source=dca1040c69df2394f24af4c691047ddd
https://www.cnblogs.com/pinard/p/9614290.html