tensorflow2数据集的制作
tensorflow2后的数据的制作和tensorflow1有了很大的改变,建议直接使用tensorflow2数据集的制作方式
下面我记录下我的数据集制作过程
1.介绍下我的数据集结构

我这数据集是这样的Columbia_Gaze_Data_Set_224是顶层路径,其中0001``是子目录,子目录下包含一些图片 图片的命名格式这样0003_2m_0P_0V_0H.jpg```其中的P,V,H是分类的种类,这里是3种大类。
2.获取所有图片的文件名字
import tensorflow as tf
import pathlib
import random
AUTOTUNE = tf.data.experimental.AUTOTUNE
data_path="./data/Columbia_Gaze_Data_Set_224"
img_root=pathlib.Path(data_path)
#for it in img_root.iterdir():
# print(it)
all_img_path=list(img_root.glob('*/*'))
#print(all_img_path[:])
all_img_paths=[str(path) for path in all_img_path]
将图片名字打乱,这一步可以不做,因为在制作成数据集后也可以再打乱
random.shuffle(all_img_paths)
3.获取根据图片路径获取所有的标签
place_label=list([[-30, -15, 0, 15, 30, -1, -1, -1],
[-10, 0, 10, -1, -1, -1, -1, -1],
[-15, -10, -5, 0, 5, 10, 15, -1]])
def GetLabelByImageName(img_name):
P_index=img_name.find('P',8)
P_val=img_name[8:P_index]
V_index=img_name.find('V',(P_index+2))
V_val=img_name[(P_index+2):V_index]
H_index = img_name.find('H', (V_index + 2))
H_val = img_name[(V_index + 2):H_index]
return [place_label[0].index(int(P_val)),place_label[1].index(int(V_val)),place_label[2].index(int(H_val))]
def GetImgName(img_path):
#print(img_path)
index=img_path.rfind('\\')
return img_path[index+1:]
#print(GetLabelByImageName(GetImgName(all_img_paths[0])))
#create all images labels
def CreateAllImageLabel(all_img_paths):
labels=[]
for item in all_img_paths:
label=GetLabelByImageName(GetImgName(item))
labels.append(label)
return labels
all_img_labels=CreateAllImageLabel(all_img_paths)
#print(len(all_img_labels))
#print(all_img_labels[0])
4.定义图片预处理方法
def preprocess_img(img_path):
image = tf.io.read_file(img_path)
img=tf.image.decode_jpeg(image)
img=tf.image.resize(img,[224,224])
#img/=255.
return img
5.制作成可用的数据集
#加载数据集
path_ds=tf.data.Dataset.from_tensor_slices(all_img_paths)
img_ds=path_ds.map(preprocess_img,num_parallel_calls=AUTOTUNE)
#加载数据集label
label_ds=tf.data.Dataset.from_tensor_slices(tf.cast(all_img_labels,tf.int64))
#将图片与label打包
img_label_ds=tf.data.Dataset.zip((img_ds,label_ds))
6.观察下制作效果
db=img_label_ds.shuffle(60000).batch(64)
sample_iter=iter(db)
sampe=next(sample_iter)
print(sampe[0].shape,sampe[1].shape)

浙公网安备 33010602011771号