bubbleeee

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::
  • 二进制文件
    • 包含多个tf.train.Example
      • Example是protocol buffer数据标准实现,包含一系列tf.train.feature属性
        • feature是key(string)-value(bytes_list || float_list ||  int64_list)键值对
  • 合理存储,二进制编码,加快数据读取和预处理速度
  • 数据转为tfrecord文件
writer = tf.python_io.TFRecordWriter(out_file_name)  # 1. 定义 writer对象

for data in dataes:
    context = dataes[0]
    question = dataes[1]
    answer = dataes[2]

    """ 2. 定义features """
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {
               'context': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=context)),
               'question': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=question)),
               'answer': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=answer))
            }))

    """ 3. 序列化,写入"""
    serialized = example.SerializeToString()
    writer.write(serialized)

 

  • tfrecord文件读取
从tfrecord文件创建TFRecordDataset
dataset = tf.data.TFRecordDataset('xxx.tfrecord')

解析tfrecord文件的每条记录,即序列化后的tf.train.Example;使用tf.parse_single_example来解析:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

其中,data_dict是一个dict,包含的key是写入tfrecord文件时用的key,相应的value则是tf.FixedLenFeature([], tf.string)、tf.FixedLenFeature([], tf.int64)、tf.FixedLenFeature([], tf.float32),分别对应不同的数据类型,汇总即有:

def parse_exmp(serial_exmp):  #label中[10]是因为一个label是一个有10个元素的列表,shape中的[x]为shape的长度 
  feats = tf.parse_single_example(serial_exmp, features={'feature':tf.FixedLenFeature([], tf.string),'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([x], tf.int64)}) 
  image = tf.decode_raw(feats['feature'], tf.float32) 
  label = feats['label'] 
  shape = tf.cast(feats['shape'], tf.int32) 
  return image, label, shape 

解析tfrecord文件中的所有记录,使用dataset的map方法,如下:

dataset = dataset.map(parse_exmp)
map方法可以接受任意函数以对dataset中的数据进行处理;另外,可使用repeat、shuffle、batch方法对dataset进行重复、混洗、分批;用repeat复制dataset以进行多个epoch;如下:
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
解析完数据后,便可以取出数据进行使用,通过创建iterator来进行,如下:
iterator = dataset.make_one_shot_iterator() 
batch_image, batch_label, batch_shape = iterator.get_next()
要把不同dataset的数据feed进行模型,则需要先创建iterator handle,即iterator placeholder,如下:
handle = tf.placeholder(tf.string, shape=[]) 
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes) 
image, label, shape = iterator.get_next()
然后为各个dataset创建handle,以feed_dict传入placeholder,如下:
with tf.Session() as sess:
  handle_train, handle_val, handle_test = sess.run([x.string_handle() for x in [iter_train, iter_val, iter_test]])
  sess.run([loss, train_op], feed_dict={handle: handle_train}

 

 
 
 
 
 
 
 
 
posted on 2022-04-03 11:14  bubbleeee  阅读(282)  评论(0编辑  收藏  举报