SARSA算法

SARSA(State - Action - Reward - State - Action)算法和Q - learning算法均为强化学习领域中用于学习最优策略的无模型算法,二者存在诸多区别,下面从多个方面进行详细阐述:

算法类型与策略特性

Q - learning:属于离线策略(off - policy)算法。这意味着在学习过程中,用于生成动作数据的行为策略(通常是 $\epsilon$-贪心策略,以一定概率 $\epsilon$ 随机选择动作进行探索,以 $1 - \epsilon$ 的概率选择Q值最大的动作进行利用)和用于更新Q值的目标策略(贪心策略,即总是选择Q值最大的动作)是不同的。这种特性使得Q - learning可以直接学习到最优策略,即使行为策略是随机探索的。

SARSA:是在线策略(on - policy)算法。在学习过程中,行为策略和目标策略是相同的,一般都采用 $\epsilon$-贪心策略。这使得SARSA更加保守,因为它会考虑到在探索过程中实际采取的动作带来的后果。

 

更新公式

Q - learning:其核心是对Q值函数进行迭代更新,更新公式为 $$Q(s,a) \leftarrow Q(s,a)+\alpha\left[r + \gamma\max_{a'}Q(s',a')-Q(s,a)\right]$$ 其中,$s$ 表示当前状态,$a$ 是当前采取的动作,$r$ 是执行动作 $a$ 后获得的即时奖励,$s'$ 是转移到的下一个状态,$\alpha$ 是学习率,控制每次更新的步长,$\gamma$ 是折扣因子,体现了对未来奖励的重视程度。$\max_{a'}Q(s',a')$ 表示在下一个状态 $s'$ 下,所有可能动作 $a'$ 对应的Q值中的最大值。这表明Q - learning在更新时总是朝着最优动作的方向进行,具有贪心的特性。

SARSA:更新公式为 $$Q(s,a) \leftarrow Q(s,a)+\alpha\left[r + \gamma Q(s',a')-Q(s,a)\right]$$ 这里的 $a'$ 是在状态 $s'$ 下根据当前策略($\epsilon$-贪心策略)实际选择的动作。与Q - learning不同,SARSA使用的是实际执行的动作 $a'$ 对应的Q值 $Q(s',a')$ 来更新当前状态 - 动作对 $(s,a)$ 的Q值。

 

探索与利用的平衡

Q - learning:由于其离线策略的特性,在探索过程中,即使采取了一些随机动作,也会尝试朝着最优策略的方向更新Q值。这可能导致在某些情况下,它会更激进地去探索可能的最优路径,甚至可能会走入一些危险区域(在环境中存在风险的情况下),因为它总是期望找到全局最优解。

SARSA:在线策略使得它会更加谨慎地对待探索。因为它的更新是基于实际执行的动作序列,所以在遇到一些可能有风险的动作时,会通过更新Q值来避免再次选择类似的动作,从而表现得更加保守。

 

收敛性与稳定性

Q - learning:在一定条件下可以收敛到最优的Q值函数 $Q^*(s,a)$,从而得到最优策略。但由于其贪心更新的特点,在某些复杂环境中可能会出现收敛速度较慢或者陷入局部最优的情况

SARSA:同样在满足一定条件下可以收敛,但由于它考虑了实际执行的动作,其收敛过程相对更加稳定,对环境中的噪声和不确定性有一定的鲁棒性。不过,它收敛到的可能不是严格意义上的最优策略,而是与当前采用的策略(如 $\epsilon$-贪心策略)相关的一个较好的策略。

 

适用场景

Q - learning:适用于环境相对稳定、奖励机制明确,且希望快速收敛到最优策略的场景。例如在一些棋盘游戏、确定性的路径规划问题中,Q - learning可以较好地发挥作用。

SARSA:更适合于环境中存在一定风险,需要谨慎探索的场景。比如在机器人导航任务中,如果环境中存在障碍物或者危险区域,SARSA可以学习到更加安全可靠的导航策略。

 

import gym
import numpy as np
import matplotlib.pyplot as plt

# 定义 SARSA 算法类
class SARSA:
    def __init__(self, env, learning_rate=0.1, discount_factor=0.9, epsilon=0.1):
        # 初始化环境
        self.env = env
        # 初始化学习率
        self.learning_rate = learning_rate
        # 初始化折扣因子
        self.discount_factor = discount_factor
        # 初始化探索率
        self.epsilon = epsilon
        # 获取环境的状态数
        self.state_space = env.observation_space.n
        # 获取环境的动作数
        self.action_space = env.action_space.n
        # 初始化 Q 表,用于存储状态 - 动作值
        self.q_table = np.zeros((self.state_space, self.action_space))

    def choose_action(self, state):
        # 以 epsilon 的概率进行探索,随机选择一个动作
        if np.random.uniform(0, 1) < self.epsilon:
            action = self.env.action_space.sample()
        else:
            # 以 1 - epsilon 的概率进行利用,选择 Q 表中当前状态下价值最大的动作
            action = np.argmax(self.q_table[state, :])
        return action

    def update(self, state, action, reward, next_state, next_action):
        # 根据 SARSA 算法更新 Q 表
        current_q = self.q_table[state, action]
        next_q = self.q_table[next_state, next_action]
        self.q_table[state, action] += self.learning_rate * (reward + self.discount_factor * next_q - current_q)

    def train(self, num_episodes):
        # 存储每一轮的总奖励
        rewards = []
        for episode in range(num_episodes):
            # 重置环境,获取初始状态
            state, _ = self.env.reset()  # 提取初始状态
            # 选择初始动作
            action = self.choose_action(state)
            total_reward = 0
            while True:
                # 执行动作,获取下一个状态、奖励、是否终止的标志和其他信息
                next_state, reward, done, _, _ = self.env.step(action)  # 提取下一个状态
                # 选择下一个动作
                next_action = self.choose_action(next_state)
                # 更新 Q 表
                self.update(state, action, reward, next_state, next_action)
                state = next_state
                action = next_action
                total_reward += reward
                if done:
                    break
            rewards.append(total_reward)
            if episode % 100 == 0:
                print(f"Episode {episode}: Total Reward = {total_reward}")
        return rewards

# 创建 CliffWalking-v0 环境
env = gym.make('CliffWalking-v0')
# 初始化 SARSA 算法实例
sarsa = SARSA(env)
# 训练 500 轮
num_episodes = 500
rewards = sarsa.train(num_episodes)

# 绘制每一轮的总奖励曲线
plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('SARSA Training in CliffWalking-v0')
plt.show()

# 关闭环境
env.close()

 

posted @ 2025-02-28 14:56  AI_Engineer  阅读(324)  评论(0)    收藏  举报