stable_baselines3 快速入门(二): 训练自定义游戏,构建Gymnasium训练环境

简介

Gymnasium 为强化学习提供了一个标准化的API,它定义了 Agent 应该如何观察世界、如何做出动作以及如何获得奖励,不管是游戏,还是工业设备,只需要满足Gymnasium标准都能使用同一套代码进行训练。

认识Gymnasium

使用stable_baselines3只需要定义好Gymnasium环境,关注训练的奖励机制,将重点放在业务的开发上而不是复杂的算法。

Gymnasium提供了几个核心的api:

方法 功能 返回值
reset() 将环境重置为初始状态,开始新回合。 obs, info
step(action) 环境向前推进一步,执行动作。 obs, reward, terminated, truncated, info
render() 可视化环境(根据 render_mode 渲染图像或弹出窗口)。 视配置而定(通常无或为 np.array)
close() 释放环境资源(关闭窗口、清理内存)。

其中的各个返回值的含义:

  • observation (Object): 当前状态的描述。例如敌人,玩家的位置,玩家的状态等
  • reward (Float): 上一步动作获得的奖励
  • terminated (Bool): 是否由于任务逻辑结束。例如:到达终点、掉进岩浆等
  • truncated (Bool): 是否由于外部限制结束。例如:达到最大步数 500 步
  • info (Dict): 辅助诊断信息,模型训练通常不用,用于用户自定义调试或记录额外统计。

手动构建环境

案例

案例描述:利用pygame构建一个简单的游戏,躲避掉落方块,利用构建的奖励机制,进行强化学习。

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pygame
import random
import cv2
import os
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env

class MyEnv(gym.Env):
    """A minimal 'dodge the falling blocks' environment.

    Observation: an 84x84x3 uint8 RGB frame rendered with pygame.
    Actions: Discrete(3) -> 0 = stay, 1 = move left, 2 = move right.
    Reward: +1 per dodged block, small living bonus, small movement
    penalty, -10 on collision (which terminates the episode).
    """

    def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()

        # Game geometry.
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode

        # 0 = stay, 1 = left, 2 = right.
        self.action_space = spaces.Discrete(3)

        # Image observation; uint8 keeps the replay buffer small and lets
        # SB3's CnnPolicy normalize internally.
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )

        pygame.init()
        if self.render_mode == "human":
            # Only open a real window when a human is watching.
            self.screen = pygame.display.set_mode((self.width, self.height))

        # Off-screen surface used both for observations and for rendering.
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

    def reset(self, seed=None, options=None):
        """Reset the game to its initial state and return (obs, info)."""
        super().reset(seed=seed)

        # Player starts centered at the bottom of the screen.
        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0

        self.current_speed = 5
        self.spawn_rate = 30

        return self._get_obs(), {}

    def step(self, action):
        """Advance one frame; returns (obs, reward, terminated, truncated, info)."""
        reward = 0
        terminated = False
        truncated = False

        # Moving costs a small penalty to discourage jittering in place.
        move_speed = 8
        if action == 1 and self.player_x > 0:
            self.player_x -= move_speed
            reward -= 0.05

        if action == 2 and self.player_x < self.width - self.player_size:
            self.player_x += move_speed
            reward -= 0.05

        self.frame_count += 1

        # Difficulty scales with score: faster enemies, shorter spawn interval.
        level = self.score // 5
        self.current_speed = 5 + level
        # Fix: the original assigned `self.spawn_rate = 30 - level * 2` (which
        # could go negative and was never read) while actually spawning with a
        # separate local `max(10, 30 - level)`. Unify on the clamped value.
        self.spawn_rate = max(10, 30 - level)

        if self.frame_count >= self.spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0])  # [x, y]

        # Player rect is loop-invariant; build it once.
        player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
        survivors = []
        for enemy in self.enemies:
            enemy[1] += self.current_speed
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)

            if player_rect.colliderect(enemy_rect):
                reward = -10
                terminated = True
                # Fix: stop processing further enemies so a later off-screen
                # enemy cannot overwrite the -10 collision penalty with +1.
                break

            if enemy[1] > self.height:
                # Block fully dodged: score it, drop it from the list.
                self.score += 1
                reward = 1
            else:
                survivors.append(enemy)

        if not terminated:
            # Fix: the original removed items from self.enemies while iterating
            # it, which skips the element after each removal; rebuild instead.
            self.enemies = survivors
            # Living bonus; doubled once the agent is clearly competent.
            if self.score > 100:
                reward += 0.01
            reward += 0.01

        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_window()

        return obs, reward, terminated, truncated, {}

    def _get_obs(self):
        """Draw the current frame off-screen and return it as 84x84x3 uint8."""
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))

        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))

        img_array = pygame.surfarray.array3d(self.canvas)
        # pygame's surfarray is (width, height, 3); swap to (height, width, 3)
        # which is what cv2 (and the observation space) expect.
        img_array = np.transpose(img_array, (1, 0, 2))
        # Downscale to 84x84 to keep the CNN input small.
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

        return obs.astype(np.uint8)

    def _render_window(self):
        """Blit the canvas to the window, draw the score, pump the event queue."""
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()

        # Pump events so the OS does not flag the window as unresponsive.
        # NOTE(review): pygame.quit() here leaves the env unusable on the next
        # frame; callers should stop stepping once the window is closed.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

    def close(self):
        """Release pygame resources (window, fonts). Called by env.close()."""
        pygame.quit()

def train():
    """Train (or resume training of) a PPO agent on MyEnv and save it.

    If models/dodge_ai.zip exists the model is loaded and training
    continues without resetting the timestep counter; otherwise a fresh
    CnnPolicy PPO model is created.
    """
    log_dir = "logs/DodgeGame"
    os.makedirs(log_dir, exist_ok=True)
    # Fix: model.save() below writes into models/, which was never created.
    os.makedirs("models", exist_ok=True)

    env = MyEnv()
    # Validate that the env respects the Gymnasium API before training.
    check_env(env)
    print("环境检查通过...")

    model_path = "models/dodge_ai.zip"
    if not os.path.exists(model_path):
        print("🆕 未发现旧模型,从头开始训练...")
        model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            tensorboard_log=log_dir,
            learning_rate=0.0001,
            n_steps=4096,
            batch_size=256,
            device="cuda")
        reset_timesteps = True
    else:
        print("发现旧模型,加载并继续训练...")
        # custom_objects overrides the hyperparameters stored in the zip.
        model = PPO.load(
            model_path,
            env=env,
            device="cuda",
            custom_objects={"learning_rate": 0.0001, "n_steps": 4096, "batch_size": 256}
        )
        # Keep the global timestep counter so TensorBoard curves continue.
        reset_timesteps = False

    print("开始训练...")

    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=reset_timesteps
    )

    # SB3 appends ".zip" automatically, matching model_path above.
    model.save("models/dodge_ai")
    print("模型已保存!")
    env.close()

def prodict():
    """Run the trained agent in a visible window forever (Ctrl+C to stop).

    NOTE(review): the name is a typo for "predict"; kept because it is
    called from the __main__ guard.
    """
    env = MyEnv(render_mode="human")
    model = PPO.load("models/dodge_ai", env=env, device="cuda")
    obs, _ = env.reset()

    # Fix: the original constructed a new Clock every iteration, so
    # tick(30) measured time since the clock's creation (effectively 0)
    # and never limited the frame rate. Create it once.
    clock = pygame.time.Clock()

    while True:
        action, _states = model.predict(obs, deterministic=True)

        obs, reward, terminated, truncated, info = env.step(action)

        if terminated or truncated:
            obs, _ = env.reset()

        clock.tick(30)  # cap playback at 30 FPS

if __name__ == "__main__":
    # Train (or resume) the model first, then replay it with rendering.
    train()

    prodict()

代码解析

代码流程如下:
构建游戏环境->训练模型->模型预测
本篇重点讲构建游戏环境,其中的pygame相关代码简略,另外两个流程参考之前文章。

构建游戏环境

初始化类

该类继承gym.Env

class MyEnv(gym.Env):
构造函数__init__
def __init__(self, render_mode=None):
        super(MyEnv, self).__init__()

        #初始化参数
        self.width = 400
        self.height = 300
        self.player_size = 30
        self.enemy_size = 30
        self.render_mode = render_mode

        self.action_space = spaces.Discrete(3)

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
        )

        pygame.init()
        if self.render_mode == "human":
            self.screen = pygame.display.set_mode((self.width, self.height))
        
        self.canvas = pygame.Surface((self.width, self.height))
        self.font = pygame.font.SysFont("monospace", 15)

在构造函数中,我们主要完成的是声明训练的维度,和输入:

  • 输入:self.action_space = spaces.Discrete(3)其中的self.action_space固定名称的父类变量spaces.Discrete(3)声明输入的数量,例如:向左 向右 和 不动3个输入。
  • 观测维度:self.observation_space也是固定名称的父类变量spaces.Box声明观测维度。
self.observation_space = spaces.Box(
    low=0, high=255, shape=(84, 84, 3), dtype=np.uint8
)
  1. low:观测参数的最小值
  2. high:观测参数的最大值
  3. shape:声明维度。例如:观测图片shape(高,宽,RGB),观测一个平面,shape(高,宽)
  4. dtype:每个变量类型,这里选np.uint8能够节省训练成本,默认是浮点型的。
任务重置 reset

相当于初始化游戏状态,即游戏的重新开始。返回的是观测值和状态信息 info(info 用于调试日志)。

def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        self.player_x = self.width // 2 - self.player_size // 2
        self.player_y = self.height - self.player_size - 10
        self.enemies = []
        self.score = 0
        self.frame_count = 0

        self.current_speed = 5
        self.spawn_rate = 30

        return self._get_obs(), {}

观测值 _get_obs
通过pygame画出的画面,然后用opencv进行简单处理:

  1. 转换坐标轴(由于opencv坐标xy轴跟pygame的xy是颠倒的)
  2. 将画面缩放到84 * 84(可以提高训练效率)
def _get_obs(self):
        self.canvas.fill((0, 0, 0))
        pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size))
        
        for enemy in self.enemies:
            pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size))

        img_array = pygame.surfarray.array3d(self.canvas)
        img_array = np.transpose(img_array, (1, 0, 2))
        obs = cv2.resize(img_array, (84, 84), interpolation=cv2.INTER_AREA)

        return obs.astype(np.uint8)
步 step(重要)

这个函数是强化训练的核心,规定了在一帧或者一步,我们给AI的分数。
分数的设置至关重要,这直接决定了训练出来AI的质量
根据下面代码(大部分都是游戏逻辑),主要讲设置奖励分数

  1. 在AI进行移动时 惩罚 0.05 分
  2. 在AI存活时 奖励 0.01分,游戏分数大于100时 存活奖励 0.02分
  3. 在障碍物完全下落时 奖励 1 分
  4. 在与障碍物碰撞时 惩罚 10 分
def step(self, action):
        reward = 0
        terminated = False
        truncated = False

        move_speed = 8
        if action == 1 and self.player_x > 0: # 
            self.player_x -= move_speed
            reward -= 0.05

        if action == 2 and self.player_x < self.width - self.player_size:
            self.player_x += move_speed
            reward -= 0.05

        self.frame_count += 1

        level = self.score // 5
        self.current_speed = 5 + level
        self.spawn_rate = 30 - level * 2
        spawn_rate = max(10, 30 - level)

        if self.frame_count >= spawn_rate:
            self.frame_count = 0
            enemy_x = random.randint(0, self.width - self.enemy_size)
            self.enemies.append([enemy_x, 0]) # [x, y]

        for enemy in self.enemies:
            enemy[1] += self.current_speed
            
            player_rect = pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size)
            enemy_rect = pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size)
            
            if player_rect.colliderect(enemy_rect):
                reward = -10 
                terminated = True

            elif enemy[1] > self.height:
                self.enemies.remove(enemy)
                self.score += 1
                reward = 1 
        
        if not terminated:
            if self.score > 100:
                reward += 0.01
            reward += 0.01

        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_window()

        return obs, reward, terminated, truncated, {}
展示游戏画面

下面完全是pygame代码,用于显示游戏画面,这里就不解释了。

def _render_window(self):
        self.screen.blit(self.canvas, (0, 0))
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.screen.blit(text, (10, 10))
        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()

你已经成功成为一名调参侠了,快来试试吧!

如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~

posted @ 2026-01-30 17:46  ClownLMe  阅读(1)  评论(0)    收藏  举报