从0到1,快速训练并使用YOLO模型

简介

YOLO是目前计算机视觉领域最前沿、应用最广泛的目标检测算法框架,他能快速识别区分目标,广泛应用于游戏,无人驾驶,工业等领域。

以识别躲避掉落滑块的游戏的物体图片作为例子。

一,环境配置

pip install ultralytics

二,准备数据集

这个格式目录如下:

my_dataset/
├── data.yaml # 配置文件(定义路径和类别)
├── train/ #训练数据集
│   ├── images/ # 训练图片
│   └── labels/ # 标注文件 (.txt)
└── val/ #验证数据集
    ├── images/ 
    └── labels/ 

data.yml

path: D:\D_MyProject\Ai\game_ai\my_dataset #数据集路径
train: train/images #训练集图片路径
val: val/images #验证集图片路径

nc: 3 #标记个数
names:  #每个标记的名称
  0: player
  1: enemy
  2: game_over

下面是用AI生成了数据集的生成脚本

数据量太多了,这里为了演示,或者学习,可以直接使用下面脚本

import pygame
import random
import sys
import os
import shutil

# =================配置区域=================
# 数据集根目录名称
DATASET_ROOT = "my_dataset"
# 采集总数量
MAX_IMAGES = 1000
# 训练集占比 (0.8 = 80% 训练, 20% 验证)
TRAIN_RATIO = 0.8

# ================= 1. 环境清理与目录创建 =================
print(f"🚀 正在初始化数据集目录: {DATASET_ROOT} ...")

# 如果目录已存在,先删除(防止旧数据混入),确保数据纯净
if os.path.exists(DATASET_ROOT):
    shutil.rmtree(DATASET_ROOT)

# 创建 YOLO 标准目录结构
for split in ['train', 'val']:
    os.makedirs(os.path.join(DATASET_ROOT, split, 'images'), exist_ok=True)
    os.makedirs(os.path.join(DATASET_ROOT, split, 'labels'), exist_ok=True)

# ================= 2. 自动生成 data.yaml =================
yaml_content = f"""
path: {os.path.abspath(DATASET_ROOT)} # 使用绝对路径,防止报错
train: train/images
val: val/images

nc: 3
names:
  0: player
  1: enemy
  2: game_over
"""
with open(os.path.join(DATASET_ROOT, "data.yaml"), "w", encoding='utf-8') as f:
    f.write(yaml_content)
print("✅ data.yaml 配置文件已生成。")

# ================= 3. 游戏与采集逻辑 =================
pygame.init()
WIDTH, HEIGHT = 800, 600
# 使用 hidden 模式或正常模式均可,这里用正常模式方便你看到进度
screen = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Auto Data Generator")
clock = pygame.time.Clock()

font_big = pygame.font.SysFont("monospace", 50)
font_small = pygame.font.SysFont("monospace", 35)

# 游戏状态
player_size, enemy_size = 50, 50
player_x = WIDTH // 2
player_y = HEIGHT - player_size - 10
enemies = []

img_count = 0
GAMEPLAY_LIMIT = int(MAX_IMAGES * 0.9) # 90% 正常游戏,10% Game Over

print(f"📸 开始采集 {MAX_IMAGES} 张图片 (自动划分 Train/Val)...")

while img_count < MAX_IMAGES:
    # 处理退出事件,防止卡死
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()

    screen.fill((0, 0, 0))
    labels = [] # 存储当前帧的标签
    dw, dh = 1.0 / WIDTH, 1.0 / HEIGHT # 归一化系数

    # --- 逻辑分支:正常游戏 vs Game Over ---
    if img_count < GAMEPLAY_LIMIT:
        # A. 正常游戏画面
        # 1. 玩家移动
        if random.random() < 0.2: # 增加移动频率
            player_x += random.choice([-15, 15])
            player_x = max(0, min(WIDTH - player_size, player_x))

        # 2. 敌人生成与移动
        if random.randint(0, 20) == 0: # 增加敌人密度
            enemies.append([random.randint(0, WIDTH - enemy_size), 0])
        
        for enemy in enemies: enemy[1] += 15 # 加快下落速度
        enemies = [e for e in enemies if e[1] < HEIGHT]

        # 3. 绘制玩家 (Class 0)
        pygame.draw.rect(screen, (50, 150, 255), (player_x, player_y, player_size, player_size))
        # 计算 YOLO 坐标 (class x_center y_center width height)
        px, py = (player_x + player_size/2) * dw, (player_y + player_size/2) * dh
        labels.append(f"0 {px:.6f} {py:.6f} {player_size*dw:.6f} {player_size*dh:.6f}")

        # 4. 绘制敌人 (Class 1)
        for e in enemies:
            pygame.draw.rect(screen, (255, 50, 50), (e[0], e[1], enemy_size, enemy_size))
            ex, ey = (e[0] + enemy_size/2) * dw, (e[1] + enemy_size/2) * dh
            labels.append(f"1 {ex:.6f} {ey:.6f} {enemy_size*dw:.6f} {enemy_size*dh:.6f}")

    else:
        # B. Game Over 画面 (Class 2)
        text_surf = font_big.render("GAME OVER", True, (255, 50, 50))
        
        # 随机抖动位置,防止过拟合
        off_x, off_y = random.randint(-50, 50), random.randint(-50, 50)
        text_x = WIDTH // 2 - text_surf.get_width() // 2 + off_x
        text_y = HEIGHT // 2 - 100 + off_y
        screen.blit(text_surf, (text_x, text_y))

        # 干扰项
        score_surf = font_small.render(f"Score: {random.randint(0,999)}", True, (255,255,255))
        screen.blit(score_surf, (WIDTH//2 - score_surf.get_width()//2, HEIGHT//2 + 50))

        # 记录标签
        tw, th = text_surf.get_width(), text_surf.get_height()
        tx, ty = (text_x + tw/2) * dw, (text_y + th/2) * dh
        labels.append(f"2 {tx:.6f} {ty:.6f} {tw*dw:.6f} {th*dh:.6f}")

    # ================= 4. 保存逻辑 (核心修改) =================
    pygame.display.flip()
    
    # 采样率:不是每一帧都保存,防止重复度过高 (这里设为 30% 概率保存)
    if random.random() < 0.3:
        # A. 决定是去 Train 还是 Val
        split_folder = "train" if random.random() < TRAIN_RATIO else "val"
        
        # B. 生成文件名
        filename = f"{img_count:06d}"
        img_save_path = os.path.join(DATASET_ROOT, split_folder, "images", f"{filename}.jpg")
        lbl_save_path = os.path.join(DATASET_ROOT, split_folder, "labels", f"{filename}.txt")

        # C. 保存图片
        pygame.image.save(screen, img_save_path)

        # D. 保存标签
        with open(lbl_save_path, "w") as f:
            f.write("\n".join(labels))

        img_count += 1
        
        # 打印进度条
        print(f"[{split_folder.upper()}] 进度: {img_count}/{MAX_IMAGES}", end="\r")

    # 加速模拟,不要垂直同步,越快越好
    clock.tick(0) 

pygame.quit()
print(f"\n\n✨ 全部完成!数据集已就绪:{os.path.abspath(DATASET_ROOT)}")
print("💡 下一步:直接运行 model.train(data='my_dataset/data.yaml')")

三,训练YOLO模型

可以看到,使用ultralytics框架训练YOLO的代码非常简单,只需要几行
注意:这里YOLO会自动下载模型并训练,下载时失败可能需要挂梯子

from ultralytics import YOLO

def train_model():
    #加载模型
    model = YOLO("yolo11n.pt") 

    #开始训练
    print("开始训练...")
    results = model.train(
        data="my_dataset/data.yaml", #数据集配置文件
        epochs=30, #训练轮数
        imgsz=640, #图片输入尺寸
        batch=16, #显存够大可以改大,比如 32 或 64
        device=0, #强制使用第一张显卡 (需要CUDA)
        workers=0, #Windows下设为0防止多进程报错
        project="dodge_project",#保存路径
        name="ai_model" #训练运行名称
    )
    print(f"训练完成!最佳模型保存在: {results.save_dir}/weights/best.pt")

if __name__ == "__main__":
    train_model()

四,使用YOLO模型

from ultralytics import YOLO

#配置路径
MODEL_PATH = "dodge_project/ai_model/weights/best.pt" 
PIC_PATH = "my_dataset/val/images/000045.jpg"

#加载模型
model = YOLO(MODEL_PATH)
print(f"模型加载成功")

#识别图片
results = model.predict(PIC_PATH, verbose=False, conf=0.4, imgsz=640)
result = results[0]

#获取信息并打印
for box in result.boxes:
    # 获取坐标 (x1, y1, x2, y2)、类别 ID 和置信度
    x1, y1, x2, y2 = map(int, box.xyxy[0])
    cls_id = int(box.cls[0])
    conf = float(box.conf[0])
    
    # 类别名称映射
    names = {0: "Player", 1: "Enemy", 2: "Game Over"}
    label = names.get(cls_id, f"Unknown({cls_id})")
    print(f"找到物体: [{label:10}] | 置信度: {conf:.2f} | 坐标: ({x1}, {y1}) -> ({x2}, {y2})")

#展示图片
result.show()

如果❤喜欢❤本系列教程,就点个关注吧,后续不定期更新~

posted @ 2026-01-31 21:54  ClownLMe  阅读(30)  评论(0)    收藏  举报