http://blog.sina.com.cn/s/blog_3fe8bc880102wfex.html

tfrecords,制作,不是太理解

  (2016-08-28 07:19:33)
   

构建自己的图片数据集tfrecords。

   先贴我的转化代码将图片文件夹下的图片转存tfrecords的数据集。

 

[python] view plain copy
 
  1. ############################################################################################  
  2. #!/usr/bin/python2.7  
  3. # -*- coding: utf-8 -*-  
  4. #Author  : zhaoqinghui  
  5. #Date    : 2016.5.10  
  6. #Function: image convert to tfrecords   
  7. #############################################################################################  
  8.   
  9. import tensorflow as tf  
  10. import numpy as np  
  11. import cv2  
  12. import os  
  13. import os.path  
  14. from PIL import Image  
  15.   
  16. #参数设置  
  17. ###############################################################################################  
  18. train_file = 'train.txt' #训练图片  
  19. name='train'      #生成train.tfrecords  
  20. output_directory='./tfrecords'  
  21. resize_height=32 #存储图片高度  
  22. resize_width=32 #存储图片宽度  
  23. ###############################################################################################  
  24. def _int64_feature(value):  
  25.     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  26.   
  27. def _bytes_feature(value):  
  28.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  29.   
  30. def load_file(examples_list_file):  
  31.     lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')])  
  32.     examples = []  
  33.     labels = []  
  34.     for example, label in lines:  
  35.         examples.append(example)  
  36.         labels.append(label)  
  37.     return np.asarray(examples), np.asarray(labels), len(lines)  
  38.   
  39. def extract_image(filename,  resize_height, resize_width):  
  40.     image = cv2.imread(filename)  
  41.     image = cv2.resize(image, (resize_height, resize_width))  
  42.     b,g,r = cv2.split(image)         
  43.     rgb_image = cv2.merge([r,g,b])       
  44.     return rgb_image  
  45.   
  46. def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width):  
  47.     if not os.path.exists(output_directory) or os.path.isfile(output_directory):  
  48.         os.makedirs(output_directory)  
  49.     _examples, _labels, examples_num = load_file(train_file)  
  50.     filename = output_directory + "/" + name + '.tfrecords'  
  51.     writer = tf.python_io.TFRecordWriter(filename)  
  52.     for i, [example, label] in enumerate(zip(_examples, _labels)):  
  53.         print('No.%d' % (i))  
  54.         image = extract_image(example, resize_height, resize_width)  
  55.         print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))  
  56.         image_raw = image.tostring()  
  57.         example = tf.train.Example(features=tf.train.Features(feature={  
  58.             'image_raw': _bytes_feature(image_raw),  
  59.             'height': _int64_feature(image.shape[0]),  
  60.             'width': _int64_feature(image.shape[1]),  
  61.             'depth': _int64_feature(image.shape[2]),  
  62.             'label': _int64_feature(label)  
  63.         }))  
  64.         writer.write(example.SerializeToString())  
  65.     writer.close()  
  66.   
  67. def disp_tfrecords(tfrecord_list_file):  
  68.     filename_queue = tf.train.string_input_producer([tfrecord_list_file])  
  69.     reader = tf.TFRecordReader()  
  70.     _, serialized_example = reader.read(filename_queue)  
  71.     features = tf.parse_single_example(  
  72.         serialized_example,  
  73.  features={  
  74.           'image_raw': tf.FixedLenFeature([], tf.string),  
  75.           'height': tf.FixedLenFeature([], tf.int64),  
  76.           'width': tf.FixedLenFeature([], tf.int64),  
  77.           'depth': tf.FixedLenFeature([], tf.int64),  
  78.           'label': tf.FixedLenFeature([], tf.int64)  
  79.       }  
  80.     )  
  81.     image = tf.decode_raw(features['image_raw'], tf.uint8)  
  82.     #print(repr(image))  
  83.     height = features['height']  
  84.     width = features['width']  
  85.     depth = features['depth']  
  86.     label = tf.cast(features['label'], tf.int32)  
  87.     init_op = tf.initialize_all_variables()  
  88.     resultImg=[]  
  89.     resultLabel=[]  
  90.     with tf.Session() as sess:  
  91.         sess.run(init_op)  
  92.         coord = tf.train.Coordinator()  
  93.         threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  94.         for i in range(21):  
  95.             image_eval = image.eval_r()  
  96.             resultLabel.append(label.eval_r())  
  97.             image_eval_reshape = image_eval.reshape([height.eval_r(), width.eval_r(), depth.eval_r()])  
  98.             resultImg.append(image_eval_reshape)  
  99.             pilimg = Image.fromarray(np.asarray(image_eval_reshape))  
  100.             pilimg.show()  
  101.         coord.request_stop()  
  102.         coord.join(threads)  
  103.         sess.close()  
  104.     return resultImg,resultLabel  
  105.   
  106. def read_tfrecord(filename_queuetemp):  
  107.     filename_queue = tf.train.string_input_producer([filename_queuetemp])  
  108.     reader = tf.TFRecordReader()  
  109.     _, serialized_example = reader.read(filename_queue)  
  110.     features = tf.parse_single_example(  
  111.         serialized_example,  
  112.         features={  
  113.           'image_raw': tf.FixedLenFeature([], tf.string),  
  114.           'width': tf.FixedLenFeature([], tf.int64),  
  115.           'depth': tf.FixedLenFeature([], tf.int64),  
  116.           'label': tf.FixedLenFeature([], tf.int64)  
  117.       }  
  118.     )  
  119.     image = tf.decode_raw(features['image_raw'], tf.uint8)  
  120.     # image  
  121.     tf.reshape(image, [256, 256, 3])  
  122.     # normalize  
  123.     image = tf.cast(image, tf.float32) * (1. /255) - 0.5  
  124.     # label  
  125.     label = tf.cast(features['label'], tf.int32)  
  126.     return image, label  
  127.   
  128. def test():  
  129.     transform2tfrecord(train_file, name , output_directory,  resize_height, resize_width) #转化函数     
  130.     img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数  
  131.     img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数  
  132.     print label  
  133.   
  134. if __name__ == '__main__':  
  135.     test()  
tfrecords,制作,不是太理解


 

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。