"""Randomly split the FFHQ images into ``train`` and ``val`` folders.

Every regular file in ``source_path`` is copied into either
``<divided_path>/train`` or ``<divided_path>/val``; 30% of the files
(21000 of the 70000 FFHQ images) go to the validation split.
"""
import os
import random
import shutil

source_path = '/data1/zjh/FFHQ/1024'
divided_path = '/data1/zjh/FFHQ_divided'


def split_dataset(src_dir, dst_dir, val_ratio=0.3, seed=0):
    """Copy the files of *src_dir* into ``train``/``val`` subfolders of *dst_dir*.

    Args:
        src_dir: Directory holding the images to split. Only regular files
            are considered; subdirectories are skipped.
        dst_dir: Output root. ``train`` and ``val`` subdirectories are
            created if they do not exist yet.
        val_ratio: Fraction of files sent to ``val``. The resulting count is
            rounded to the nearest integer.
        seed: Seed for a private RNG, so the split is reproducible and the
            global ``random`` state is left untouched.

    Returns:
        A ``(train_names, val_names)`` tuple of file-name lists, in sorted order.

    Raises:
        ValueError: if ``val_ratio`` asks for more files than exist.
    """
    # Sort so the sampled split depends only on the seed, not on the
    # (filesystem-dependent) order returned by os.listdir.
    names = sorted(
        name for name in os.listdir(src_dir)
        if os.path.isfile(os.path.join(src_dir, name))
    )

    train_dir = os.path.join(dst_dir, 'train')
    val_dir = os.path.join(dst_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    num_val = round(len(names) * val_ratio)
    # A set gives O(1) membership tests; a list made the loop below O(n^2).
    # val_names holds the file names chosen for the validation split.
    val_names = set(random.Random(seed).sample(names, k=num_val))

    train, val = [], []
    for index, name in enumerate(names):
        print(index)
        is_val = name in val_names
        # shutil.copy avoids building a shell command, so names containing
        # spaces or shell metacharacters are copied correctly.
        shutil.copy(os.path.join(src_dir, name), val_dir if is_val else train_dir)
        (val if is_val else train).append(name)
    return train, val


if __name__ == '__main__':
    split_dataset(source_path, divided_path)