使用tensorflow中的Dataset来读取制作好的tfrecords文件

上一篇我写了如何给自己的图像集制作tfrecords文件,现在我们就来讲讲如何读取已经创建好的文件,我们使用的是Tensorflow中的Dataset来读取我们的tfrecords,网上很多帖子应该是很久之前的了,绝大多数的做法是,先将tfrecords序列化成一个队列,然后使用TFRecordReader这个函数进行解析,解析出来的每一行都是一个record,然后再将每一个record进行还原,但是这个函数你在使用的时候会报出异常,原因就是它已经被dataset中新的读取方式所替代,下个版本中可能就无法使用了,因此不建议大家使用这个函数,好了,下面就来看看是如何进行读取的吧。

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 #定义可以一次获得多张图像的函数
 5 def show_image(image_dir):
 6     plt.imshow(image_dir)
 7     plt.axis('on')
 8     plt.show()
 9 
10 #单个record的解析函数
11 def decode_example(example):#,resize_height,resize_width,labels_nums):
12     features=tf.io.parse_single_example(example,features={
13         'image_raw':tf.io.FixedLenFeature([],tf.string),
14         'label':tf.io.FixedLenFeature([],tf.int64)
15     })
16     tf_image=tf.decode_raw(features['image_raw'],tf.uint8)#这个其实就是图像的像素模式,之前我们使用矩阵来表示图像
17     tf_image=tf.reshape(tf_image,shape=[224,224,3])#对图像的尺寸进行调整,调整成三通道图像
18     tf_image=tf.cast(tf_image,tf.float32)*(1./255)#对图像进行归一化以便保持和原图像有相同的精度
19     tf_label=tf.cast(features['label'],tf.int32)
20     tf_label=tf.one_hot(tf_label,5,on_value=1,off_value=0)#将label转化成用one_hot编码的格式
21     return tf_image,tf_label
22 
23 def batch_test(tfrecords_file):
24     dataset=tf.data.TFRecordDataset(tfrecords_file)
25     dataset=dataset.map(decode_example)
26     dataset=dataset.shuffle(100).batch(4)
27     iterator=tf.compat.v1.data.make_one_shot_iterator(dataset)
28     batch_images,batch_labels=iterator.get_next()
29 
30     init_op=tf.compat.v1.global_variables_initializer()
31     with tf.compat.v1.Session() as sess:
32         sess.run(init_op)
33         coord=tf.train.Coordinator()
34         threads=tf.train.start_queue_runners(coord=coord)
35         for i in range(4):
36             images,labels=sess.run([batch_images,batch_labels])
37             show_image(images[1,:,:,:])
38             print('shape:{},tpye:{},labels:{}'.format(images.shape, images.dtype, labels))
39 
40         coord.request_stop()
41         coord.join(threads)
42 
43 if __name__=='__main__':
44     tfrecords_file='D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords'
45     resize_height=224
46     resize_width=224
47     batch_test(tfrecords_file)

我为了测试,写了batch_test这个函数,因为我想试一试看我做的tfrecords能不能被解析成功,如果你不想测试只想训练,那你直接把images_batch,和labels_batch放到网络中进行训练就可以了,还有一点要注意的,tf.global_variables_initializer()已经被tf.compat.v1.global_variables_initializer()所取代了,我做的时候不知道所以报了一个warning提示,同时tf.Sesssion()已经被tf.compat.v1.Session() 所替代,iterator=dataset.make_one_shot_iterator()已经被tf.compat.v1.data.make_one_shot_iterator(dataset)  所代替,这些异常要注意,然后我只是将每个batch的第二张图片显示出来了,你也可以显示其他的,但是意义不大,反正只是测试一下解析成功与否,成功了我们就不需要纠结别的了。好啦,就是这样,接下来我会把这些东西放到网络中进行训练,再更新我的学习,就酱。

posted @ 2019-09-02 09:10  daremosiranaihana  阅读(1526)  评论(0编辑  收藏  举报