Reinforcement Learning (6): Sarsa

I. Concepts

1. Reinforcement learning involves two problems: prediction and control.

Temporal-difference (TD) learning solves the prediction problem.

Sarsa is an on-policy TD control algorithm; it solves the control problem.

2. Control algorithms come in two kinds: on-policy ("online" control) and off-policy ("offline" control).

On-policy control uses a single policy (e.g. ε-greedy) both to update the value function and to select actions.

Off-policy control uses two policies: one to update the value function and another to select actions.

3. The name Sarsa comes from the tuple it learns from: S (state), A (action), R (reward), S' (next state), A' (next action).

4. Comparison with Q-learning

Q-learning estimates the next step by bootstrapping from the highest-valued action of the next state, but it does not necessarily execute that action afterwards.

Sarsa also estimates the next action, but then actually takes it, so the action used in the update is the action that gets executed; it can be seen as a modification of the Q-learning algorithm (see the short sketch at the end of this list).

5. Advantages: it needs neither an environment model nor complete state sequences, so it is widely used in classical (tabular) reinforcement learning.

6. Disadvantage: because it is limited by the Q table, it cannot solve problems that are too complex (very large state spaces).
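
To make the Q-learning/Sarsa difference concrete, here is a minimal sketch with toy values (not code from this post); Q is a toy Q table, and a_next stands for the action the ε-greedy policy actually picks in the next state:

# Toy Q table {state: {action: value}}; gamma, r, s_next, a_next are assumed inputs
gamma = 0.9
r = 1.0
Q = {'s2': {0: 0.5, 1: 2.0}}
s_next, a_next = 's2', 0   # a_next = action actually chosen by the ε-greedy policy

# Q-learning (off-policy): bootstrap from the best (max) action of the next state
q_target_qlearning = r + gamma * max(Q[s_next].values())   # 1 + 0.9 * 2.0 = 2.8

# Sarsa (on-policy): bootstrap from the action the policy will actually take next
q_target_sarsa = r + gamma * Q[s_next][a_next]             # 1 + 0.9 * 0.5 = 1.45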

 

II. Sarsa Update Formula

1. Estimate of Q(S,A): the current table value Q(S,A)

2. "Reality" (target) of Q(S,A): R + γ·Q(S',A')

3. Error = reality − estimate = R + γ·Q(S',A') − Q(S,A)

4. Update: add α times the error to Q(S,A), i.e. Q(S,A) = Q(S,A) + α·(R + γ·Q(S',A') − Q(S,A)), where α is the step size (learning rate) and γ is the reward decay (discount factor)

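For example (toy numbers chosen here for illustration): with α = 0.1, γ = 0.9, R = 1, Q(S,A) = 0.5 and Q(S',A') = 2.0, the error is 1 + 0.9 × 2.0 − 0.5 = 2.3, so the updated value is Q(S,A) = 0.5 + 0.1 × 2.3 = 0.73.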

 

III. Sarsa Algorithm

1. Input: number of iteration rounds T, state set S, action set A, step size α, decay factor γ, exploration rate ε

2. Output: the Q values of all state-action pairs
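
Only the input and output are listed above; the full procedure can be sketched as follows. This is a minimal illustration, assuming a hypothetical env that exposes an actions list plus reset()/step() methods (step returning the next state, the reward, and a done flag); it is not the maze code shown further down.

import random
from collections import defaultdict

def sarsa(env, episodes=100, alpha=0.1, gamma=0.9, epsilon=0.1):
    """Tabular Sarsa sketch: on-policy TD control."""
    Q = defaultdict(float)   # Q[(state, action)], defaults to 0

    def epsilon_greedy(state):
        # explore with probability epsilon, otherwise act greedily w.r.t. Q
        if random.random() < epsilon:
            return random.choice(env.actions)
        return max(env.actions, key=lambda a: Q[(state, a)])

    for _ in range(episodes):
        s = env.reset()
        a = epsilon_greedy(s)                 # choose A from S
        done = False
        while not done:
            s_, r, done = env.step(a)         # take A, observe R and S'
            a_ = epsilon_greedy(s_)           # choose A' from S'
            target = r if done else r + gamma * Q[(s_, a_)]
            Q[(s, a)] += alpha * (target - Q[(s, a)])   # TD update with step size alpha
            s, a = s_, a_                     # the chosen A' is the action actually executed
    return Q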

 

IV. Sarsa(λ)

1. Single-step update: after each step, once the feedback is received, that step is updated immediately.

2. Episode (round) update: only after a whole episode has finished are all the steps of that episode updated, and every step is updated with the same strength.

3. Single-step and episode updates each have pros and cons, which is why λ was introduced: it sits between the two, and the closer a step is to the reward, the stronger its update.

4. λ (0 < λ < 1) is the decay applied to past steps (the trace decay), while the γ in the update formula is the reward decay.

5. Update formula (E is the eligibility trace; α and γ are the same step size and reward decay as before):
error: δ = R + γ·Q(S',A') − Q(S,A)
mark the visited pair: E(S,A) ← E(S,A) + 1 (accumulating trace), or E(S,·) ← 0 and E(S,A) ← 1 (replacing trace)
update every pair: Q(s,a) ← Q(s,a) + α·δ·E(s,a)
decay every pair: E(s,a) ← γ·λ·E(s,a)


6. The implementation adds one extra table, eligibility_trace, initialised as a copy of the Q table; it records the decaying trace of the steps that have been taken.
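
A toy illustration (not code from this post) of how the trace makes steps closer to the reward receive larger updates, assuming a replacing trace with γ = 0.9 and λ = 0.5:

gamma, lambda_ = 0.9, 0.5
trace = {}                                   # eligibility of each visited (state, action)
for step, sa in enumerate([('s1', 0), ('s2', 0), ('s3', 0)]):
    # decay every existing trace, then mark the freshly visited pair
    trace = {k: v * gamma * lambda_ for k, v in trace.items()}
    trace[sa] = 1.0
    print(step, trace)
# After the third step: ('s1', 0) ≈ 0.20, ('s2', 0) = 0.45, ('s3', 0) = 1.0,
# so when the reward arrives, the Q update is scaled most strongly for the latest steps.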

 

V. Code

1. Algorithm update (training loop)

from maze_env import Maze
from sarsa.RL_brain import SarsaTable, SarsaLambdaTable


def update():
    for episode in range(100):
        # reset the environment and get the initial observation
        observation = env.reset()

        # choose the first action with the ε-greedy policy
        action = RL.choose_action(str(observation))

        while True:
            env.render()

            # take the action, observe the next state, the reward and whether the episode is over
            observation_, reward, done = env.step(action)

            # Sarsa: the next action must be chosen from the NEXT observation
            action_ = RL.choose_action(str(observation_))

            # learn from the (S, A, R, S', A') transition
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # the chosen next action is the one that will actually be executed
            observation = observation_
            action = action_

            if done:
                break

    print('game over')
    env.destroy()


if __name__ == '__main__':
    env = Maze()
    RL = SarsaLambdaTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()

 

2. Decision making (RL_brain)

import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        self.check_state_exist(observation)
        if np.random.uniform() < self.epsilon:
            # exploit: pick the highest-valued action, shuffling first so ties are broken at random
            state_action = self.q_table.loc[observation, :]
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # explore: pick a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            self.q_table = pd.concat(
                [self.q_table,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )


class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # off-policy target: bootstrap from the best (max) action of the next state
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)


class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # on-policy target: use the next action a_ that will actually be taken
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)


class SarsaLambdaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.1):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            # on-policy target: use the next action a_ that will actually be taken
            q_target = r + self.gamma * self.q_table.loc[s_, a_]
        else:
            q_target = r

        error = q_target - q_predict

        # Method 1: accumulating trace
        # self.eligibility_trace.loc[s, a] += 1

        # Method 2: replacing trace (the one used here; Method 1 is kept only for reference)
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # update every state-action pair in proportion to its eligibility
        self.q_table += self.lr * error * self.eligibility_trace

        # decay all traces, so steps further from the reward receive smaller updates
        self.eligibility_trace *= self.gamma * self.lambda_

    # both the Q table and the eligibility_trace table need the new state appended
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            self.q_table = pd.concat(
                [self.q_table,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )

            self.eligibility_trace = pd.concat(
                [self.eligibility_trace,
                 pd.Series(
                     [0] * len(self.actions),
                     index=self.q_table.columns,
                     name=state
                 ).to_frame().T]
            )

References:

https://www.bilibili.com/video/BV13W411Y75P/?share_source=copy_web&vd_source=dca1040c69df2394f24af4c691047ddd

https://www.cnblogs.com/pinard/p/9614290.html

 
