tensorflow2数据集的制作

tensorflow2后的数据的制作和tensorflow1有了很大的改变,建议直接使用tensorflow2数据集的制作方式
下面我记录下我的数据集制作过程

1.介绍下我的数据集结构


我这数据集是这样的Columbia_Gaze_Data_Set_224是顶层路径,其中0001``是子目录,子目录下包含一些图片 图片的命名格式这样0003_2m_0P_0V_0H.jpg```其中的P,V,H是分类的种类,这里是3种大类。

2.获取所有图片的文件名字

import tensorflow as tf
import pathlib
import random

AUTOTUNE = tf.data.experimental.AUTOTUNE
data_path="./data/Columbia_Gaze_Data_Set_224"
img_root=pathlib.Path(data_path)

#for it in img_root.iterdir():
#    print(it)
all_img_path=list(img_root.glob('*/*'))
#print(all_img_path[:])

all_img_paths=[str(path) for path in all_img_path]

将图片名字打乱,这一步可以不做,因为在制作成数据集后也可以再打乱

random.shuffle(all_img_paths)

3.获取根据图片路径获取所有的标签

place_label=list([[-30, -15, 0, 15, 30, -1, -1, -1],
             [-10, 0, 10, -1, -1, -1, -1, -1],
             [-15, -10, -5, 0, 5, 10, 15, -1]])

def GetLabelByImageName(img_name):
    P_index=img_name.find('P',8)
    P_val=img_name[8:P_index]
    V_index=img_name.find('V',(P_index+2))
    V_val=img_name[(P_index+2):V_index]
    H_index = img_name.find('H', (V_index + 2))
    H_val = img_name[(V_index + 2):H_index]
    return [place_label[0].index(int(P_val)),place_label[1].index(int(V_val)),place_label[2].index(int(H_val))]

def GetImgName(img_path):
    #print(img_path)
    index=img_path.rfind('\\')
    return img_path[index+1:]

#print(GetLabelByImageName(GetImgName(all_img_paths[0])))

#create all images labels
def CreateAllImageLabel(all_img_paths):
    labels=[]
    for item in all_img_paths:
        label=GetLabelByImageName(GetImgName(item))
        labels.append(label)
    return labels
all_img_labels=CreateAllImageLabel(all_img_paths)
#print(len(all_img_labels))
#print(all_img_labels[0])

4.定义图片预处理方法

def preprocess_img(img_path):
    image = tf.io.read_file(img_path)
    img=tf.image.decode_jpeg(image)
    img=tf.image.resize(img,[224,224])
    #img/=255.
    return img

5.制作成可用的数据集

#加载数据集
path_ds=tf.data.Dataset.from_tensor_slices(all_img_paths)
img_ds=path_ds.map(preprocess_img,num_parallel_calls=AUTOTUNE)

#加载数据集label
label_ds=tf.data.Dataset.from_tensor_slices(tf.cast(all_img_labels,tf.int64))

#将图片与label打包
img_label_ds=tf.data.Dataset.zip((img_ds,label_ds))

6.观察下制作效果

db=img_label_ds.shuffle(60000).batch(64)
sample_iter=iter(db)
sampe=next(sample_iter)
print(sampe[0].shape,sampe[1].shape)
posted @ 2021-04-16 18:01  cyssmile  阅读(572)  评论(0)    收藏  举报