import os
import numpy as np
from sklearn.model_selection import KFold
import json
############################ Collect image indices ###################################
# Directory whose entries are scanned for image files.
data_path = r'./training_set'

# Walk the directory in sorted (lexicographic) order. Entries whose last
# '_'-separated token contains 'Annotation' are skipped; for every other entry
# only the text before the first '_' is kept (e.g. '51'), as a string.
index_list = [
    file_name.partition('_')[0]
    for file_name in sorted(os.listdir(data_path))
    if 'Annotation' not in file_name.rpartition('_')[2]
]
print('index_list:', index_list)
####################################################################################
###################### Build the cross-validation splits ######################
k = 5  # number of folds
# Shuffled K-fold with a fixed seed so the split is reproducible across runs.
kf = KFold(n_splits=k, shuffle=True, random_state=1)

# Convert once so the integer position arrays yielded by KFold can be used
# for fancy indexing inside the loop.
index_array = np.array(index_list)

# Maps the 1-based fold number to the image indices of that fold's train and
# test partitions.
write_dic = {}
for fold, (train, test) in enumerate(kf.split(index_list), start=1):
    train_index = index_array[train]
    test_index = index_array[test]
    print('train:', len(train_index))
    # Report the size of the test partition (previously the train size was
    # printed here by mistake).
    print('test:', len(test_index))
    # tolist() yields plain Python values, which json can serialise directly.
    write_dic[fold] = {
        'train': train_index.tolist(),
        'test': test_index.tolist(),
    }

# Persist the split as JSON. Note that JSON object keys are always strings, so
# the fold numbers come back as '1'..'k' when the file is read with json.load.
save_path = '{}_split_index.txt'.format(k)
with open(save_path, 'w') as f:
    json.dump(write_dic, f)
print()
########################## Read-back test ###############################
# load_path = '{}_split_index.txt'.format(k)
# with open(load_path,'r') as f:
# data = f.read()
# data = json.loads(data)
# print(data)