18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
本文系统讲解从基本强化学习方法到高级技术(如PPO、A3C、PlaNet等)的实现原理与编码过程,旨在通过理论结合代码的方式,构建对强化学习算法的全面理解。
为确保内容易于理解和实践,全部代码均在Jupyter Notebook环境中实现,仅依赖基础库进行算法构建。
代码库组织结构如下:
├── 1_simple_rl.ipynb
├── 2_q_learning.ipynb
├── 3_sarsa.ipynb
...
├── 9_a3c.ipynb
├── 10_ddpg.ipynb
├── 11_sac.ipynb
├── 12_trpo.ipynb
...
├── 17_mcts.ipynb
└── 18_planet.ipynb
说明:github地址见文章最后,文章很长所以可以根据需求查看感兴趣的强化学习方法介绍和对应notebook。
搭建环境
首先,需要克隆仓库并安装相关依赖项:
# 克隆并导航到目录
git clone https://github.com/fareedkhan-dev/all-rl-algorithms.git
cd all-rl-algorithms
# 安装所需的依赖项
pip install -r requirements.txt
接下来,导入核心库:
# --- 核心Python库 ---
import random
import math
from collections import defaultdict, deque, namedtuple
from typing import List, Tuple, Dict, Optional, Any, DefaultDict # 用于代码中的类型提示
# --- 数值计算 ---
import numpy as np
# --- 机器学习框架(PyTorch - 从REINFORCE开始广泛使用) ---
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical, Normal # 用于策略梯度、SAC、PlaNet等
# --- 环境 ---
# 用于加载标准环境,如Pendulum
import gymnasium as gym
# 注意:SimpleGridWorld类定义需要直接包含在代码中
# 因为它是博客文章中定义的自定义环境。
# --- 可视化(由博客中显示的图表暗示) ---
import matplotlib.pyplot as plt
import seaborn as sns # 经常用于热力图
# --- 可能用于异步方法(A3C) ---
# 尽管在代码片段中没有明确展示,但A3C实现通常使用这些
# import torch.multiprocessing as mp # 或标准的'multiprocessing'/'threading'
# --- PyTorch设置(可选但是好习惯) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 禁用警告(可选) ---
import warnings
warnings.filterwarnings('ignore') # 抑制潜在的废弃警告等
强化学习环境设置
虽然OpenAI Gym库提供了常见的强化学习环境,但为了深入理解算法核心原理,我们将自行实现大部分环境。仅在少数需要特殊环境配置的算法中,才会使用Gym模块。
本文主要关注两个环境:
- 自定义网格世界(从头实现)
- 钟摆问题(使用OpenAI Gymnasium)
# -------------------------------------
# 1. 简单自定义网格世界
# -------------------------------------
class SimpleGridWorld:
""" 一个基本的网格世界环境。 """
def __init__(self, size=5):
self.size = size
self.start_state = (0, 0)
self.goal_state = (size - 1, size - 1)
self.state = self.start_state
# 动作: 0:上, 1:下, 2:左, 3:右
self.action_map = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}
self.action_space_size = 4
def reset(self) -> Tuple[int, int]:
""" 重置到初始状态。 """
self.state = self.start_state
return self.state
def step(self, action: int) -> Tuple[Tuple[int, int], float, bool]:
""" 执行一个动作,返回next_state, reward, done。 """
if self.state == self.goal_state:
return self.state, 0.0, True # 在目标处停留
# 计算潜在的下一个状态
dr, dc = self.action_map[action]
r, c = self.state
next_r, next_c = r + dr, c + dc
# 应用边界(如果碰到墙壁则原地不动)
if not (0 <= next_r < self.size and 0 <= next_c < self.size):
next_r, next_c = r, c # 保持在当前状态
reward = -1.0 # 墙壁惩罚
else:
reward = -0.1 # 步骤成本
# 更新状态
self.state = (next_r, next_c)
# 检查是否达到目标
done = (self.state == self.goal_state)
if done:
reward = 10.0 # 目标奖励
return self.state, reward, done
SimpleGridWorld
环境是一个基础的二维网格强化学习环境,智能体需要从起始位置
(0,0)
导航至目标位置
(size-1, size-1)
。智能体可以执行四个基本方向的移动动作(上、下、左、右),在每一步会接收一个小的步骤惩罚(-0.1),碰撞墙壁则会获得更大的惩罚(-1.0),而到达目标则给予较大的奖励(10.0)。
# -------------------------------------
# 2. 加载Gymnasium钟摆
# -------------------------------------
pendulum_env = gym.make('Pendulum-v1')
print("Pendulum-v1 environment loaded.")
# 重置环境
observation, info = pendulum_env.reset(seed=42)
print(f"Initial Observation: {observation}")
print(f"Observation Space: {pendulum_env.observation_space}")
print(f"Action Space: {pendulum_env.action_space}")
# 执行随机步骤
random_action = pendulum_env.action_space.sample()
observation, reward, terminated, truncated, info = pendulum_env.step(random_action)
done = terminated or truncated
print(f"Step with action {random_action}:")
print(f" Next Obs: {observation}\n Reward: {reward}\n Done: {done}")
# 关闭环境(如果使用了渲染则很重要)
pendulum_env.close()
对于钟摆问题,我们使用Gymnasium库中的
Pendulum-v1
环境,这是一个基于物理的连续控制任务。上述代码初始化环境并展示了基本交互过程,包括获取初始观察、显示观察空间和动作空间的结构,以及执行一个随机动作并处理反馈。
让我们可视化这两个环境:
网格世界和钟摆
从上图可以看出,在网格世界环境中,智能体的目标是找到从起点到目标的最短路径;而在钟摆环境中,目标是将摆杆从任意初始位置控制到竖直向上的平衡点。
https://avoid.overfit.cn/post/2785f28f0c094343a5b51466a1df29cf