stable_baselines3 Quick Start (Part 2): Training on a Custom Game by Building a Gymnasium Environment
Introduction
Gymnasium provides a standardized API for reinforcement learning: it defines how an agent observes the world, how it takes actions, and how it receives rewards. Whether the task is a game or a piece of industrial equipment, anything that conforms to the Gymnasium interface can be trained with the same code.
Getting to Know Gymnasium
With stable_baselines3 you only need to define a Gymnasium environment and focus on the reward design, putting your effort into the application itself rather than the underlying algorithms.
Gymnasium exposes a few core API methods:
| Method | Purpose | Returns |
|---|---|---|
| reset() | Resets the environment to its initial state and starts a new episode. | obs, info |
| step(action) | Advances the environment by one step, executing the given action. | obs, reward, terminated, truncated, info |
| render() | Visualizes the environment (renders an image or opens a window, depending on render_mode). | Depends on configuration (usually nothing, or an np.array) |
| close() | Releases environment resources (closes windows, frees memory). | None |
The meanings of the return values are:
- observation (object): a description of the current state, e.g. the enemies, the player's position, the player's status.
- reward (float): the reward obtained by the previous action.
- terminated (bool): whether the episode ended because of the task logic, e.g. reaching the goal or falling into lava.
- truncated (bool): whether the episode ended because of an external limit, e.g. hitting a maximum of 500 steps.
- info (dict): auxiliary diagnostic information; usually not used by training, intended for custom debugging or extra statistics.
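To make the API concrete, here is a minimal interaction loop. This sketch uses the built-in CartPole-v1 environment rather than this post's custom game, and just plays random actions:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)
for _ in range(100):
    action = env.action_space.sample()  # pick a random action
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:         # the episode ended, start a new one
        obs, info = env.reset()
env.close()
```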
Building an Environment by Hand
Example
Task description: build a simple dodge-the-falling-blocks game with pygame, design a reward scheme for it, and train an agent with reinforcement learning.
```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env


class MyEnv(gym.Env):
    def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()
        # initialize parameters
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )
        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0
        self.current_speed = 5
        self.spawn_rate = 30
        return self._get_obs(), {}

    def step(self, action):
        reward = 0
        terminated = False
        truncated = False
        move_speed = 8
        if action == 1 and self.player_x > 0:  # move left
            self.player_x -= move_speed
            reward -= 0.05
        if action == 2 and self.player_x < self.width - self.player_size:  # move right
            self.player_x += move_speed
            reward -= 0.05
        self.frame_count += 1
        level = self.score // 5
        self.current_speed = 5 + level
        self.spawn_rate = max(10, 30 - level)  # spawn faster as the score grows, but never below every 10 frames
        if self.frame_count >= self.spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0])  # [x, y]
        for enemy in self.enemies[:]:  # iterate over a copy so enemies can be removed safely
            enemy[1] += self.current_speed
            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)
            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True
                break  # the episode is over; keep the collision penalty
            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1
        if not terminated:
            if self.score > 100:
                reward += 0.01
            reward += 0.01
        obs = self._get_obs()
        if self.render_mode == "human":
            self._render_window()
        return obs, reward, terminated, truncated, {}

    def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
        img_array = pygame.surfarray.array3d(self.canvas)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
        return obs.astype(np.uint8)

    def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()


def train():
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)
    env = MyEnv()
    check_env(env)
    print("Environment check passed...")
    model_path = "models/dodge_ai.zip"
    if not os.path.exists(model_path):
        print("🆕 No existing model found, training from scratch...")
        model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
        reset_timesteps = True
    else:
        print("Existing model found, loading and continuing training...")
        model = PPO.load(
            model_path,
            env=env,
            device="cuda",
            custom_objects={"learning_rate": 0.0001, "n_steps": 4096, "batch_size": 256}
        )
        reset_timesteps = False
    print("Starting training...")
    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=reset_timesteps
    )
    model.save("models/dodge_ai")
    print("Model saved!")
    env.close()


def predict():
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()
        pygame.time.Clock().tick(30)


if __name__ == "__main__":
    train()
    predict()
```
Code Walkthrough
The overall flow of the code is:
Build the game environment -> train the model -> run the trained model
This post focuses on building the game environment; the pygame-specific code is only sketched, and the other two stages are covered in the earlier posts of this series.
Building the Game Environment
Defining the environment class
The environment class inherits from gym.Env:
```python
class MyEnv(gym.Env):
```
The constructor __init__
```python
def __init__(self, render_mode=None):
    super(MyEnv, self).__init__()
    # initialize parameters
    self.width = 400
    self.height = 300
    self.player_size = 30
    self.enemy_size = 30
    self.render_mode = render_mode
    self.action_space = spaces.Discrete(3)
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
    )
    pygame.init()
    if self.render_mode == "human":
        self.screen = pygame.display.set_mode((self.width, self.height))
    self.canvas = pygame.Surface((self.width, self.height))
    self.font = pygame.font.SysFont("monospace", 15)
```
In the constructor we mainly declare the action (input) space and the observation space:
- Action space (input): `self.action_space = spaces.Discrete(3)`. `self.action_space` is a fixed attribute name defined by the parent class, and `spaces.Discrete(3)` declares the number of discrete actions, e.g. move left, move right, and stand still (3 actions in total).
- Observation space: `self.observation_space` is likewise a fixed attribute name from the parent class; `spaces.Box` declares the shape of the observation:
```python
self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)
```
- low: the minimum value of each observation element
- high: the maximum value of each observation element
- shape: the observation dimensions, e.g. an image observation has shape (height, width, RGB channels), while a plain 2D plane would be shape (height, width)
- dtype: the element type; choosing np.uint8 here reduces training cost, since the default is floating point
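The image observation used here is what CnnPolicy expects, but other space types work the same way. As a sketch (not part of this post's game), a low-dimensional vector observation could be declared like this; the four features named in the comment are only an illustrative assumption:

```python
from gymnasium import spaces
import numpy as np

# Hypothetical alternative: observe a handful of numbers instead of pixels,
# e.g. player_x, nearest_enemy_x, nearest_enemy_y, current_speed (normalized to [0, 1]).
vector_obs_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)
```

A vector observation like this would be trained with "MlpPolicy" instead of "CnnPolicy".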
Resetting an episode: reset
This is effectively re-initializing the game, i.e. starting over. It returns the observation and an info dict (useful for debugging and logging).
```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)
    self.player_x = self.width // 2 - self.player_size // 2
    self.player_y = self.height - self.player_size - 10
    self.enemies = []
    self.score = 0
    self.frame_count = 0
    self.current_speed = 5
    self.spawn_rate = 30
    return self._get_obs(), {}
```
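One detail worth noting: the game spawns enemies with Python's random module, which does not honor the seed passed into reset. If reproducible episodes matter, a sketch like the one below could be used instead; it relies on self.np_random, the numpy generator that super().reset(seed=seed) seeds for you (spawn_enemy_x is a hypothetical helper, not part of the original code):

```python
def spawn_enemy_x(self):
    # self.np_random is a numpy Generator provided by gymnasium.Env after reset(seed=...).
    # integers() has an exclusive upper bound, hence the +1 to match random.randint.
    return int(self.np_random.integers(0, self.width - self.enemy_size + 1))
```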
The observation _get_obs:
The frame drawn by pygame is post-processed with opencv:
- Swap the axes (opencv's x/y axes are transposed relative to pygame's)
- Resize the frame to 84 * 84 (which speeds up training)
```python
def _get_obs(self):
    self.canvas.fill((0, 0, 0))
    pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
    for enemy in self.enemies:
        pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))
    img_array = pygame.surfarray.array3d(self.canvas)
    img_array = np.transpose(img_array, (1, 0, 2))
    obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)
    return obs.astype(np.uint8)
```
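A quick way to confirm the returned frame really matches the declared space (a sanity check added here, not part of the original post) is Gymnasium's Space.contains:

```python
env = MyEnv()
obs, _ = env.reset()
assert obs.shape == (84, 84, 3) and obs.dtype == np.uint8
assert env.observation_space.contains(obs)  # the frame fits the declared Box space
env.close()
```

Note that check_env in the train() function performs a similar validation automatically.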
The step function (important)
This function is the core of the training: it defines the reward the AI receives at each frame/step.
The reward design is crucial; it directly determines the quality of the trained agent.
In the code below (most of which is plain game logic), the reward scheme is:
- Moving: penalty of 0.05
- Staying alive: reward of 0.01 per step, increased to 0.02 once the score exceeds 100
- An obstacle falling all the way off the screen: reward of 1
- Colliding with an obstacle: penalty of 10
```python
def step(self, action):
    reward = 0
    terminated = False
    truncated = False
    move_speed = 8
    if action == 1 and self.player_x > 0:  # move left
        self.player_x -= move_speed
        reward -= 0.05
    if action == 2 and self.player_x < self.width - self.player_size:  # move right
        self.player_x += move_speed
        reward -= 0.05
    self.frame_count += 1
    level = self.score // 5
    self.current_speed = 5 + level
    self.spawn_rate = max(10, 30 - level)  # spawn faster as the score grows, but never below every 10 frames
    if self.frame_count >= self.spawn_rate:
        self.frame_count = 0
        enemy_x = random.randint(0, self.width - self.enemy_size)
        self.enemies.append([enemy_x, 0])  # [x, y]
    for enemy in self.enemies[:]:  # iterate over a copy so enemies can be removed safely
        enemy[1] += self.current_speed
        player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
        enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)
        if player_rect.colliderect(enemy_rect):
            reward = -10
            terminated = True
            break  # the episode is over; keep the collision penalty
        elif enemy[1] > self.height:
            self.enemies.remove(enemy)
            self.score += 1
            reward = 1
    if not terminated:
        if self.score > 100:
            reward += 0.01
        reward += 0.01
    obs = self._get_obs()
    if self.render_mode == "human":
        self._render_window()
    return obs, reward, terminated, truncated, {}
```
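Before spending GPU hours on PPO, it can also help to sanity-check the reward scheme with a short random rollout. The sketch below just plays random actions and prints the accumulated return:

```python
env = MyEnv()
obs, _ = env.reset()
total_reward = 0.0
for _ in range(200):
    action = env.action_space.sample()  # random policy
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    if terminated or truncated:
        obs, _ = env.reset()
print("Return of a random policy over 200 steps:", total_reward)
env.close()
```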
Rendering the game window
The following is plain pygame code that simply draws the frame to the screen; it is not explained further here.
```python
def _render_window(self):
    self.screen.blit(self.canvas, (0, 0))
    text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
    self.screen.blit(text, (10, 10))
    pygame.display.flip()
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
```
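If you ever want to record evaluation videos, Gymnasium environments usually also support render_mode="rgb_array". The sketch below assumes you add that mode yourself (a render method is not part of the original code); it returns the current frame as an array instead of drawing a window:

```python
def render(self):
    # Sketch for render_mode="rgb_array"; "human" mode is already handled
    # inside step() via _render_window().
    if self.render_mode == "rgb_array":
        frame = pygame.surfarray.array3d(self.canvas)
        return np.transpose(frame, (1, 0, 2))  # pygame uses (x, y); images use (row, col)
```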
You have now officially become a hyperparameter tuner. Go give it a try!
If you ❤ enjoy ❤ this series, please follow; it will be updated from time to time~

浙公网安备 33010602011771号