stable_baseline3 快速入门(一): 训练第一个强化学习模型

简介

stable_baseline3 是一个基于 PyTorch 的强化学习算法开源库,里面集成了多种强化学习算法,使用这个开源库能够让我们不需要过度关注强化学习算法细节,专注于AI业务的开发。

环境配置

pip install stable-baselines3
pip install gymnasium

这里stable-baselines3会默认安装pytroch框架,但是是不带cuda版本的,这就意味着我们无法利用我们的显卡对模型进行训练。
下载cuda版本的pytroch步骤如下:

  1. 卸载原来版本的pytroch框架
pip uninstall torch torchvision torchaudio -y
#这个是针对RTX 30/40/50显卡的。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

如果其他版本请参考官网: https://pytorch.org/get-started/locally/

认识stable_baseline3

stable_baseline3提供了许多模型,如下列表:

名称 动作空间 建议应用场景 核心优势
PPO 连续 & 离散 全能选手,如机器人走动、金融交易、游戏 AI 极其稳定,对超参数不敏感,支持大规模并行训练。
DQN 仅离散 经典游戏(Atari)、开关控制、迷宫寻路 理解简单,在离散控制领域非常经典且有效。
SAC 仅连续 复杂物理模拟、机械臂抓取、自动驾驶 探索效率极高,能自动寻找最优路径且不轻易陷入局部最优。
TD3 仅连续 工业控制、无人机飞行、精密动作 针对 DDPG 的缺陷做了改进,训练过程比 SAC 更平滑。
A2C 连续 & 离散 简单逻辑测试、快速原型验证 结构简单,虽然不如 PPO 稳定,但在特定并行环境下速度极快。

声明模型中,可以设置多种参数,这里列出常用的:
目前不需要搞懂都有什么作用,后面有文章会详细讲解

  1. 训练参数
  • learning_rate:学习率
  • gamma:折扣因子
  • batch_size:更新模型使用数据量
  • verbose:打印信息模式。0-静默模式,1-信息模式,2-调试模式
  • device:指定训练设备cuda使用显卡,cpu使用cpu
  1. 模型规则
  • MlpPolicy:多层感知机。适用于状态是数值场景(传感器等)
  • CnnPolicy:卷积神经网络。适用于状态是图像场景(游戏等)

训练第一个强化学习模型

案例

案例描述:训练一个gymnasium默认提供的游戏环境,平衡杆游戏。

import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", env, verbose=1, device="cuda")

print("开始训练...")
model.learn(total_timesteps=10000)

print("正在保存模型...")
model.save("ppo_cartpole")

print("正在读取模型...")
env = gym.make("CartPole-v1", render_mode="human")
loaded_model = PPO.load("ppo_cartpole", env=env)

print("训练结束,开始演示...")
obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
    
    if terminated or truncated:
        obs, _ = env.reset()

env.close()

代码解释

代码流程如下:
初始化环境模型->训练模型->保存模型->加载模型->模型预测

初始化环境模型

初始化模型以及游戏的环境

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")

env = gym.make("CartPole-v1", render_mode="human")
  • gym中的make方法利用默认的游戏环境,CartPole-v1是游戏名,下面有一个render_mode="human"参数,用于标识是否展示画面。训练时展示画面会降低训练的速度,一般在预测时才使用
训练模型
model.learn(total_timesteps=10000)
  • total_timesteps:训练10000次
保存模型
model.save("ppo_cartpole")
  • "ppo_cartpole" 为保存模型的名字,这里是保存在当前文件夹中。
加载模型
loaded_model = PPO.load("ppo_cartpole", env=env)
  • 第一个参数:刚刚保存的模型路径
  • 第二个参数:训练的环境
模型预测
obs, _ = env.reset()
for i in range(1000):
    action, _states = loaded_model.predict(obs, deterministic=True)

    obs, reward, terminated, truncated, info = env.step(action)
    
    if terminated or truncated:
        obs, _ = env.reset()
  • env.reset()重置环境,返回初始观测值obsinfo(这里没用到)
  • 模型的predict方法用于根据观测值obs预测下一步行动。注意:deterministic参数要为True,不然会报错
  • 模型的step方法根据行动值返回结果。(这些都是什么后面文章会讲)

如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~

posted @ 2026-01-29 16:01  ClownLMe  阅读(0)  评论(0)    收藏  举报