18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
本文系统讲解从基本强化学习方法到高级技术(如PPO、A3C、PlaNet等)的实现原理与编码过程,旨在通过理论结合代码的方式,构建对强化学习算法的全面理解。
为确保内容易于理解和实践,全部代码均在Jupyter Notebook环境中实现,仅依赖基础库进行算法构建。
代码库组织结构如下:
├── 1_simple_rl.ipynb├── 2_q_learning.ipynb├── 3_sarsa.ipynb...├── 9_a3c.ipynb├── 10_ddpg.ipynb├── 11_sac.ipynb├── 12_trpo.ipynb...├── 17_mcts.ipynb└── 18_planet.ipynb
说明:github地址见文章最后,文章很长所以可以根据需求查看感兴趣的强化学习方法介绍和对应notebook。
搭建环境
首先,需要克隆仓库并安装相关依赖项:
# 克隆并导航到目录git clone https://github.com/fareedkhan-dev/all-rl-algorithms.gitcd all-rl-algorithms# 安装所需的依赖项pip install -r requirements.txt
接下来,导入核心库:
# --- 核心Python库 ---import randomimport mathfrom collections import defaultdict, deque, namedtuplefrom typing import List, Tuple, Dict, Optional, Any, DefaultDict # 用于代码中的类型提示# --- 数值计算 ---import numpy as np# --- 机器学习框架(PyTorch - 从REINFORCE开始广泛使用) ---import torchimport torch.nn as nnimport torch.optim as optimimport torch.nn.functional as Ffrom torch.distributions import Categorical, Normal # 用于策略梯度、SAC、PlaNet等# --- 环境 ---# 用于加载标准环境,如Pendulumimport gymnasium as gym# 注意:SimpleGridWorld类定义需要直接包含在代码中# 因为它是博客文章中定义的自定义环境。# --- 可视化(由博客中显示的图表暗示) ---import matplotlib.pyplot as pltimport seaborn as sns # 经常用于热力图# --- 可能用于异步方法(A3C) ---# 尽管在代码片段中没有明确展示,但A3C实现通常使用这些# import torch.multiprocessing as mp # 或标准的'multiprocessing'/'threading'# --- PyTorch设置(可选但是好习惯) ---device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# --- 禁用警告(可选) ---import warningswarnings.filterwarnings('ignore') # 抑制潜在的废弃警告等
强化学习环境设置
虽然OpenAI Gym库提供了常见的强化学习环境,但为了深入理解算法核心原理,我们将自行实现大部分环境。仅在少数需要特殊环境配置的算法中,才会使用Gym模块。
本文主要关注两个环境:
- 自定义网格世界(从头实现)
- 钟摆问题(使用OpenAI Gymnasium)
# -------------------------------------# 1. 简单自定义网格世界# -------------------------------------class SimpleGridWorld:""" 一个基本的网格世界环境。 """def __init__(self, size=5):self.size = sizeself.start_state = (0, 0)self.goal_state = (size - 1, size - 1)self.state = self.start_state# 动作: 0:上, 1:下, 2:左, 3:右self.action_map = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}self.action_space_size = 4def reset(self) -> Tuple[int, int]:""" 重置到初始状态。 """self.state = self.start_statereturn self.statedef step(self, action: int) -> Tuple[Tuple[int, int], float, bool]:""" 执行一个动作,返回next_state, reward, done。 """if self.state == self.goal_state:return self.state, 0.0, True # 在目标处停留# 计算潜在的下一个状态dr, dc = self.action_map[action]r, c = self.statenext_r, next_c = r + dr, c + dc# 应用边界(如果碰到墙壁则原地不动)if not (0 <= next_r < self.size and 0 <= next_c < self.size):next_r, next_c = r, c # 保持在当前状态reward = -1.0 # 墙壁惩罚else:reward = -0.1 # 步骤成本# 更新状态self.state = (next_r, next_c)# 检查是否达到目标done = (self.state == self.goal_state)if done:reward = 10.0 # 目标奖励return self.state, reward, done
SimpleGridWorld
环境是一个基础的二维网格强化学习环境,智能体需要从起始位置
(0,0)
导航至目标位置
(size-1, size-1)
。智能体可以执行四个基本方向的移动动作(上、下、左、右),在每一步会接收一个小的步骤惩罚(-0.1),碰撞墙壁则会获得更大的惩罚(-1.0),而到达目标则给予较大的奖励(10.0)。
# -------------------------------------# 2. 加载Gymnasium钟摆# -------------------------------------pendulum_env = gym.make('Pendulum-v1')print("Pendulum-v1 environment loaded.")# 重置环境observation, info = pendulum_env.reset(seed=42)print(f"Initial Observation: {observation}")print(f"Observation Space: {pendulum_env.observation_space}")print(f"Action Space: {pendulum_env.action_space}")# 执行随机步骤random_action = pendulum_env.action_space.sample()observation, reward, terminated, truncated, info = pendulum_env.step(random_action)done = terminated or truncatedprint(f"Step with action {random_action}:")print(f" Next Obs: {observation}\n Reward: {reward}\n Done: {done}")# 关闭环境(如果使用了渲染则很重要)pendulum_env.close()
对于钟摆问题,我们使用Gymnasium库中的
Pendulum-v1
环境,这是一个基于物理的连续控制任务。上述代码初始化环境并展示了基本交互过程,包括获取初始观察、显示观察空间和动作空间的结构,以及执行一个随机动作并处理反馈。
让我们可视化这两个环境:
网格世界和钟摆
从上图可以看出,在网格世界环境中,智能体的目标是找到从起点到目标的最短路径;而在钟摆环境中,目标是将摆杆从任意初始位置控制到竖直向上的平衡点。
https://avoid.overfit.cn/post/2785f28f0c094343a5b51466a1df29cf

浙公网安备 33010602011771号