stable_baseline3 快速入门(一): 训练第一个强化学习模型
简介
stable_baseline3 是一个基于 PyTorch 的强化学习算法开源库,里面集成了多种强化学习算法,使用这个开源库能够让我们不需要过度关注强化学习算法细节,专注于AI业务的开发。
环境配置
pip install stable-baselines3
pip install gymnasium
这里stable-baselines3会默认安装pytroch框架,但是是不带cuda版本的,这就意味着我们无法利用我们的显卡对模型进行训练。
下载cuda版本的pytroch步骤如下:
- 卸载原来版本的
pytroch框架
pip uninstall torch torchvision torchaudio -y
#这个是针对RTX 30/40/50显卡的。
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
如果其他版本请参考官网: https://pytorch.org/get-started/locally/
认识stable_baseline3
stable_baseline3提供了许多模型,如下列表:
| 名称 | 动作空间 | 建议应用场景 | 核心优势 |
|---|---|---|---|
| PPO | 连续 & 离散 | 全能选手,如机器人走动、金融交易、游戏 AI | 极其稳定,对超参数不敏感,支持大规模并行训练。 |
| DQN | 仅离散 | 经典游戏(Atari)、开关控制、迷宫寻路 | 理解简单,在离散控制领域非常经典且有效。 |
| SAC | 仅连续 | 复杂物理模拟、机械臂抓取、自动驾驶 | 探索效率极高,能自动寻找最优路径且不轻易陷入局部最优。 |
| TD3 | 仅连续 | 工业控制、无人机飞行、精密动作 | 针对 DDPG 的缺陷做了改进,训练过程比 SAC 更平滑。 |
| A2C | 连续 & 离散 | 简单逻辑测试、快速原型验证 | 结构简单,虽然不如 PPO 稳定,但在特定并行环境下速度极快。 |
在声明模型中,可以设置多种参数,这里列出常用的:
目前不需要搞懂都有什么作用,后面有文章会详细讲解
- 训练参数
learning_rate:学习率gamma:折扣因子batch_size:更新模型使用数据量verbose:打印信息模式。0-静默模式,1-信息模式,2-调试模式device:指定训练设备cuda使用显卡,cpu使用cpu
- 模型规则
MlpPolicy:多层感知机。适用于状态是数值场景(传感器等)CnnPolicy:卷积神经网络。适用于状态是图像场景(游戏等)
训练第一个强化学习模型
案例
案例描述:训练一个gymnasium默认提供的游戏环境,平衡杆游戏。
import gymnasium as gym
from stable_baselines3 import PPO
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")
print("开始训练...")
model.learn(total_timesteps=10000)
print("正在保存模型...")
model.save("ppo_cartpole")
print("正在读取模型...")
env = gym.make("CartPole-v1", render_mode="human")
loaded_model = PPO.load("ppo_cartpole", env=env)
print("训练结束,开始演示...")
obs, _ = env.reset()
for i in range(1000):
action, _states = loaded_model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
env.close()
代码解释
代码流程如下:
初始化环境模型->训练模型->保存模型->加载模型->模型预测
初始化环境模型
初始化模型以及游戏的环境
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1, device="cuda")
env = gym.make("CartPole-v1", render_mode="human")
gym中的make方法利用默认的游戏环境,CartPole-v1是游戏名,下面有一个render_mode="human"参数,用于标识是否展示画面。训练时展示画面会降低训练的速度,一般在预测时才使用
训练模型
model.learn(total_timesteps=10000)
total_timesteps:训练10000次
保存模型
model.save("ppo_cartpole")
"ppo_cartpole"为保存模型的名字,这里是保存在当前文件夹中。
加载模型
loaded_model = PPO.load("ppo_cartpole", env=env)
- 第一个参数:刚刚保存的模型路径
- 第二个参数:训练的环境
模型预测
obs, _ = env.reset()
for i in range(1000):
action, _states = loaded_model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
env.reset()重置环境,返回初始观测值obs和info(这里没用到)- 模型的
predict方法用于根据观测值obs预测下一步行动。注意:deterministic参数要为True,不然会报错 - 模型的
step方法根据行动值返回结果。(这些都是什么后面文章会讲)
如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~

浙公网安备 33010602011771号