18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现

本文系统讲解从基本强化学习方法到高级技术(如PPO、A3C、PlaNet等)的实现原理与编码过程,旨在通过理论结合代码的方式,构建对强化学习算法的全面理解。

为确保内容易于理解和实践,全部代码均在Jupyter Notebook环境中实现,仅依赖基础库进行算法构建。

代码库组织结构如下:

  1. ├── 1_simple_rl.ipynb
  2. ├── 2_q_learning.ipynb
  3. ├── 3_sarsa.ipynb
  4. ...
  5. ├── 9_a3c.ipynb
  6. ├── 10_ddpg.ipynb
  7. ├── 11_sac.ipynb
  8. ├── 12_trpo.ipynb
  9. ...
  10. ├── 17_mcts.ipynb
  11. └── 18_planet.ipynb

说明:github地址见文章最后,文章很长所以可以根据需求查看感兴趣的强化学习方法介绍和对应notebook。

搭建环境

首先,需要克隆仓库并安装相关依赖项:

  1. # 克隆并导航到目录
  2. git clone https://github.com/fareedkhan-dev/all-rl-algorithms.git
  3. cd all-rl-algorithms
  4. # 安装所需的依赖项
  5. pip install -r requirements.txt

接下来,导入核心库:

  1. # --- 核心Python库 ---
  2. import random
  3. import math
  4. from collections import defaultdict, deque, namedtuple
  5. from typing import List, Tuple, Dict, Optional, Any, DefaultDict # 用于代码中的类型提示
  6. # --- 数值计算 ---
  7. import numpy as np
  8. # --- 机器学习框架(PyTorch - 从REINFORCE开始广泛使用) ---
  9. import torch
  10. import torch.nn as nn
  11. import torch.optim as optim
  12. import torch.nn.functional as F
  13. from torch.distributions import Categorical, Normal # 用于策略梯度、SAC、PlaNet等
  14. # --- 环境 ---
  15. # 用于加载标准环境,如Pendulum
  16. import gymnasium as gym
  17. # 注意:SimpleGridWorld类定义需要直接包含在代码中
  18. # 因为它是博客文章中定义的自定义环境。
  19. # --- 可视化(由博客中显示的图表暗示) ---
  20. import matplotlib.pyplot as plt
  21. import seaborn as sns # 经常用于热力图
  22. # --- 可能用于异步方法(A3C) ---
  23. # 尽管在代码片段中没有明确展示,但A3C实现通常使用这些
  24. # import torch.multiprocessing as mp # 或标准的'multiprocessing'/'threading'
  25. # --- PyTorch设置(可选但是好习惯) ---
  26. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  27. print(f"Using device: {device}")
  28. # --- 禁用警告(可选) ---
  29. import warnings
  30. warnings.filterwarnings('ignore') # 抑制潜在的废弃警告等

强化学习环境设置

虽然OpenAI Gym库提供了常见的强化学习环境,但为了深入理解算法核心原理,我们将自行实现大部分环境。仅在少数需要特殊环境配置的算法中,才会使用Gym模块。

本文主要关注两个环境:

  1. 自定义网格世界(从头实现)
  2. 钟摆问题(使用OpenAI Gymnasium)
  1. # -------------------------------------
  2. # 1. 简单自定义网格世界
  3. # -------------------------------------
  4. class SimpleGridWorld:
  5. """ 一个基本的网格世界环境。 """
  6. def __init__(self, size=5):
  7. self.size = size
  8. self.start_state = (0, 0)
  9. self.goal_state = (size - 1, size - 1)
  10. self.state = self.start_state
  11. # 动作: 0:上, 1:下, 2:左, 3:右
  12. self.action_map = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
  13. self.action_space_size = 4
  14. def reset(self) -> Tuple[int, int]:
  15. """ 重置到初始状态。 """
  16. self.state = self.start_state
  17. return self.state
  18. def step(self, action: int) -> Tuple[Tuple[int, int], float, bool]:
  19. """ 执行一个动作,返回next_state, reward, done。 """
  20. if self.state == self.goal_state:
  21. return self.state, 0.0, True # 在目标处停留
  22. # 计算潜在的下一个状态
  23. dr, dc = self.action_map[action]
  24. r, c = self.state
  25. next_r, next_c = r + dr, c + dc
  26. # 应用边界(如果碰到墙壁则原地不动)
  27. if not (0 <= next_r < self.size and 0 <= next_c < self.size):
  28. next_r, next_c = r, c # 保持在当前状态
  29. reward = -1.0 # 墙壁惩罚
  30. else:
  31. reward = -0.1 # 步骤成本
  32. # 更新状态
  33. self.state = (next_r, next_c)
  34. # 检查是否达到目标
  35. done = (self.state == self.goal_state)
  36. if done:
  37. reward = 10.0 # 目标奖励
  38. return self.state, reward, done
  1. SimpleGridWorld

环境是一个基础的二维网格强化学习环境,智能体需要从起始位置

  1. (0,0)

导航至目标位置

  1. (size-1, size-1)

。智能体可以执行四个基本方向的移动动作(上、下、左、右),在每一步会接收一个小的步骤惩罚(-0.1),碰撞墙壁则会获得更大的惩罚(-1.0),而到达目标则给予较大的奖励(10.0)。

  1. # -------------------------------------
  2. # 2. 加载Gymnasium钟摆
  3. # -------------------------------------
  4. pendulum_env = gym.make('Pendulum-v1')
  5. print("Pendulum-v1 environment loaded.")
  6. # 重置环境
  7. observation, info = pendulum_env.reset(seed=42)
  8. print(f"Initial Observation: {observation}")
  9. print(f"Observation Space: {pendulum_env.observation_space}")
  10. print(f"Action Space: {pendulum_env.action_space}")
  11. # 执行随机步骤
  12. random_action = pendulum_env.action_space.sample()
  13. observation, reward, terminated, truncated, info = pendulum_env.step(random_action)
  14. done = terminated or truncated
  15. print(f"Step with action {random_action}:")
  16. print(f" Next Obs: {observation}\n Reward: {reward}\n Done: {done}")
  17. # 关闭环境(如果使用了渲染则很重要)
  18. pendulum_env.close()

对于钟摆问题,我们使用Gymnasium库中的

  1. Pendulum-v1

环境,这是一个基于物理的连续控制任务。上述代码初始化环境并展示了基本交互过程,包括获取初始观察、显示观察空间和动作空间的结构,以及执行一个随机动作并处理反馈。

让我们可视化这两个环境:

网格世界和钟摆

从上图可以看出,在网格世界环境中,智能体的目标是找到从起点到目标的最短路径;而在钟摆环境中,目标是将摆杆从任意初始位置控制到竖直向上的平衡点。

 

https://avoid.overfit.cn/post/2785f28f0c094343a5b51466a1df29cf

posted @ 2025-04-11 09:59  deephub  阅读(45)  评论(0)    收藏  举报