python脚本划分数据集
利用python脚本对文件夹中的大量文件划分训练集train、验证集val和测试集test。source_dir为源文件夹,source_dir目录中可以包含不同种类的文件夹。
import os
import shutil
import random
from pathlib import Path
def split_dataset(source_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
"""
将数据集按照指定比例分割为训练集、验证集和测试集
参数:
source_dir: 原始数据集目录
train_ratio: 训练集比例
val_ratio: 验证集比例
test_ratio: 测试集比例
"""
# 确保比例之和为1
assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "比例之和必须为1"
# 创建目标目录
train_dir = os.path.join(os.path.dirname(source_dir), "train")
val_dir = os.path.join(os.path.dirname(source_dir), "val")
test_dir = os.path.join(os.path.dirname(source_dir), "test")
for dir_path in [train_dir, val_dir, test_dir]:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
print(f"创建目录: {dir_path}")
# 遍历源目录中的所有文件和子目录
for root, dirs, files in os.walk(source_dir):
# 跳过空目录
if not files:
continue
# 为当前目录在目标目录中创建相应的子目录结构
relative_path = os.path.relpath(root, source_dir)
for dir_path in [train_dir, val_dir, test_dir]:
target_dir = os.path.join(dir_path, relative_path)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
# 随机打乱文件顺序
random.shuffle(files)
total_files = len(files)
# 计算各集合的文件数量
train_count = int(total_files * train_ratio)
val_count = int(total_files * val_ratio)
# 测试集数量 = 剩余的文件
test_count = total_files - train_count - val_count
# 分配文件到各个集合
train_files = files[:train_count]
val_files = files[train_count:train_count + val_count]
test_files = files[train_count + val_count:]
# 复制文件到相应的目录
for file in train_files:
src = os.path.join(root, file)
dst = os.path.join(train_dir, relative_path, file)
shutil.copy2(src, dst)
for file in val_files:
src = os.path.join(root, file)
dst = os.path.join(val_dir, relative_path, file)
shutil.copy2(src, dst)
for file in test_files:
src = os.path.join(root, file)
dst = os.path.join(test_dir, relative_path, file)
shutil.copy2(src, dst)
print(f"处理目录: {relative_path}")
print(f" 训练集: {len(train_files)} 个文件")
print(f" 验证集: {len(val_files)} 个文件")
print(f" 测试集: {len(test_files)} 个文件")
print("数据集分割完成!")
if __name__ == "__main__":
# 设置源数据集目录
# 请将此处替换为你的原始数据集目录
source_directory = input("请输入原始数据集目录路径: ").strip()
# 检查源目录是否存在
if not os.path.isdir(source_directory):
print(f"错误: 目录 '{source_directory}' 不存在!")
else:
# 按7:2:1的比例分割数据集
split_dataset(source_directory, 0.7, 0.2, 0.1)
God will send the rain when you are ready.You need to prepare your field to receive it.