python脚本划分数据集

利用python脚本对文件夹中的大量文件划分训练集train、验证集val和测试集test。source_dir为源文件夹,source_dir目录中可以包含不同种类的文件夹。

import os
import shutil
import random
from pathlib import Path

def split_dataset(source_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    """
    将数据集按照指定比例分割为训练集、验证集和测试集

    参数:
        source_dir: 原始数据集目录
        train_ratio: 训练集比例
        val_ratio: 验证集比例
        test_ratio: 测试集比例
    """
    # 确保比例之和为1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "比例之和必须为1"

    # 创建目标目录
    train_dir = os.path.join(os.path.dirname(source_dir), "train")
    val_dir = os.path.join(os.path.dirname(source_dir), "val")
    test_dir = os.path.join(os.path.dirname(source_dir), "test")

    for dir_path in [train_dir, val_dir, test_dir]:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
            print(f"创建目录: {dir_path}")

    # 遍历源目录中的所有文件和子目录
    for root, dirs, files in os.walk(source_dir):
        # 跳过空目录
        if not files:
            continue

        # 为当前目录在目标目录中创建相应的子目录结构
        relative_path = os.path.relpath(root, source_dir)
        for dir_path in [train_dir, val_dir, test_dir]:
            target_dir = os.path.join(dir_path, relative_path)
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)

        # 随机打乱文件顺序
        random.shuffle(files)
        total_files = len(files)

        # 计算各集合的文件数量
        train_count = int(total_files * train_ratio)
        val_count = int(total_files * val_ratio)
        # 测试集数量 = 剩余的文件
        test_count = total_files - train_count - val_count

        # 分配文件到各个集合
        train_files = files[:train_count]
        val_files = files[train_count:train_count + val_count]
        test_files = files[train_count + val_count:]

        # 复制文件到相应的目录
        for file in train_files:
            src = os.path.join(root, file)
            dst = os.path.join(train_dir, relative_path, file)
            shutil.copy2(src, dst)

        for file in val_files:
            src = os.path.join(root, file)
            dst = os.path.join(val_dir, relative_path, file)
            shutil.copy2(src, dst)

        for file in test_files:
            src = os.path.join(root, file)
            dst = os.path.join(test_dir, relative_path, file)
            shutil.copy2(src, dst)

        print(f"处理目录: {relative_path}")
        print(f"  训练集: {len(train_files)} 个文件")
        print(f"  验证集: {len(val_files)} 个文件")
        print(f"  测试集: {len(test_files)} 个文件")

    print("数据集分割完成!")


if __name__ == "__main__":
    # 设置源数据集目录
    # 请将此处替换为你的原始数据集目录
    source_directory = input("请输入原始数据集目录路径: ").strip()

    # 检查源目录是否存在
    if not os.path.isdir(source_directory):
        print(f"错误: 目录 '{source_directory}' 不存在!")
    else:
        # 按7:2:1的比例分割数据集
        split_dataset(source_directory, 0.7, 0.2, 0.1)
posted @ 2025-09-23 15:27  Steven0325  阅读(17)  评论(0)    收藏  举报