划分数据集代码
数据集组织格式
├─annotations
│ ├─test
│ ├─train
│ └─val
└─images
  ├─test
  ├─train
  └─val
分割脚本:
import os
import glob
import random
import json
from shutil import copy
# Split the dataset into train / val / test subsets.
#
# Source layout read by this script (from the paths used below):
#   <src_data_folder>/images/*.png
#   <src_data_folder>/annotations/data.json   (a JSON object keyed by image file name)
# Output layout written by this script:
#   <target_data_folder>/images/{train,val,test}/<image file>
#   <target_data_folder>/annotations/{train,val,test}/{train,val,test}_data.json

SPLIT_NAMES = ('train', 'val', 'test')


def split_image_list(imglist, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Shuffle *imglist* and cut it into train / val / test partitions.

    Args:
        imglist: iterable of image paths (or file names).
        train_ratio: fraction of items assigned to ``train``.
        val_ratio: fraction of items assigned to ``val``; the remainder goes
            to ``test``.
        seed: seed for the shuffle. ``None`` gives a different split on every
            run; pass an int to make the split reproducible.

    Returns:
        dict mapping each name in ``SPLIT_NAMES`` to a list of items.

    Raises:
        ValueError: if a ratio is negative or the ratios sum to more than 1.
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError('ratios must be non-negative and sum to at most 1')
    # Sort first: glob() order is filesystem-dependent, so without this even a
    # fixed seed would not reproduce the same split across machines.
    items = sorted(imglist)
    # A private Random instance leaves the module-level RNG state untouched.
    random.Random(seed).shuffle(items)
    total = len(items)
    train_end = int(train_ratio * total)
    val_end = int((train_ratio + val_ratio) * total)
    return {
        'train': items[:train_end],
        'val': items[train_end:val_end],
        'test': items[val_end:],
    }


def split_dataset(src_data_folder, target_data_folder,
                  train_ratio=0.8, val_ratio=0.1, seed=None):
    """Split a folder of PNG images plus one JSON annotation file.

    Images matching ``<src_data_folder>/images/*.png`` are shuffled and split
    with :func:`split_image_list`. Each entry of
    ``<src_data_folder>/annotations/data.json`` is routed to the split of the
    image with the same file name. The image is copied to
    ``<target_data_folder>/images/<split>/`` and the entry is written to
    ``<target_data_folder>/annotations/<split>/<split>_data.json``.

    Annotation keys without a matching PNG in the split are reported and
    skipped.

    Returns:
        dict mapping each split name to the annotation dict written for it.
    """
    # Create the output directory tree (no-op for directories that exist).
    for first_dir in ('images', 'annotations'):
        for split_name in SPLIT_NAMES:
            os.makedirs(os.path.join(target_data_folder, first_dir, split_name),
                        exist_ok=True)

    # Split by image.
    imglist = glob.glob(os.path.join(src_data_folder, 'images', '*.png'))
    splits = split_image_list(imglist, train_ratio, val_ratio, seed)
    print(len(imglist))
    print(sum(len(paths) for paths in splits.values()))

    # Image file name -> split name: one O(1) lookup per annotation instead of
    # scanning the train/val lists for every key.
    split_of = {os.path.basename(path): name
                for name, paths in splits.items() for path in paths}
    split_json = {name: {} for name in SPLIT_NAMES}

    # Read the combined annotation file and route each entry to its split.
    with open(os.path.join(src_data_folder, 'annotations', 'data.json'),
              'r', encoding='utf-8') as f:
        content = json.load(f)
    for key, value in content.items():
        print('img_name:' + key)
        split_name = split_of.get(key)
        if split_name is None:
            # Do not default such keys to "test". The image is either missing
            # (copy() would raise) or not a .png, which would skew the split.
            print('skipped (no matching .png image): ' + key)
            continue
        split_json[split_name][key] = value
        copy(os.path.join(src_data_folder, 'images', key),
             os.path.join(target_data_folder, 'images', split_name, key))

    # Write one annotation file per split, each holding its OWN entries.
    for split_name in SPLIT_NAMES:
        out_path = os.path.join(target_data_folder, 'annotations', split_name,
                                split_name + '_data.json')
        with open(out_path, 'w', encoding='utf-8') as f_out:
            json.dump(split_json[split_name], f_out)
    return split_json


if __name__ == '__main__':
    # Source dataset root
    src_data_folder = './handlandmark_dataset'
    # Destination root for the split dataset
    target_data_folder = './data/Dataset'
    # Pass seed=<int> to get the same split on every run.
    split_dataset(src_data_folder, target_data_folder)