TFRecord 和 tf.Example
当利用tf.data读取数据时成为训练的瓶颈,可能需要转换TFRecord 格式。
协议缓冲区是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由 .proto 文件定义。 tf.Example 消息(或 protobuf)是一种灵活的消息类型,表示 {"string": value} 映射。它专为 TensorFlow 而设计,并被用于 TFX 等高级 API。
根据官方文档下面介绍创建、解析和使用 tf.Example 消息,以及如何在 .tfrecord 文件之间对 tf.Example 消息进行序列化、写入和读取。
1. tf.example
tf.Example 是 {"string": tf.train.Feature} 映射。如何从已有的数据创建tf.Example消息呢?先准备一下已有的数据集:(比较特殊,非图像文本,由多类型组成)
# The number of observations in the dataset. n_observations = int(1e4) # Boolean feature, encoded as False or True. feature0 = np.random.choice([False, True], n_observations) # Integer feature, random from 0 to 4. feature1 = np.random.randint(0, 5, n_observations) # String feature strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat']) feature2 = strings[feature1] # Float feature, from a standard normal distribution feature3 = np.random.randn(n_observations)
上面的数据集具有 4 个特征:
- 具有相等
False或True概率的布尔特征 - 从
[0, 5]均匀随机选择的整数特征 - 通过将整数特征作为索引从字符串表生成的字符串特征
- 来自标准正态分布的浮点特征
要将这些numpy特征转为tf.Example需要几个转换函数:下面的函数将标准 TensorFlow 类型转换为兼容 tf.Example 的 tf.train.Feature
# The following functions can be used to convert a value to a type compatible # with tf.Example. def _bytes_feature(value): """Returns a bytes_list from a string / byte.""" if isinstance(value, type(tf.constant(0))): value = value.numpy() # BytesList won't unpack a string from an EagerTensor. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _float_feature(value): """Returns a float_list from a float / double.""" return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) def _int64_feature(value): """Returns an int64_list from a bool / enum / int / uint.""" return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
有了辅助函数,已有数据集转为tf.Example消息的方式有三步:
-
在每个观测结果中,需要使用上述其中一种函数,将每个值转换为包含三种兼容类型之一的
tf.train.Feature。 -
创建一个从特征名称字符串到第 1 步中生成的编码特征值的映射(字典)。
-
将第 2 步中生成的映射转换为
Features消息。
举个例子,假设从数据集中获得了某一条数据: [False, 4, bytes('goat'), 0.9876]。可以将该条数据转为 tf.Example 消息。根据上面的3个步骤可以实现如下:
def serialize_example(feature0, feature1, feature2, feature3): """ Creates a tf.Example message ready to be written to a file. """ # Create a dictionary mapping the feature name to the tf.Example-compatible # data type. feature = { 'feature0': _int64_feature(feature0), # 因为输入的数据有4哥feature,所以这里写成4条,并写成字典形式 {特征名称:tf.train.Feature} 'feature1': _int64_feature(feature1), 'feature2': _bytes_feature(feature2), 'feature3': _float_feature(feature3), } # Create a Features message using tf.train.Example. example_proto = tf.train.Example(features=tf.train.Features(feature=feature)) # 将映射转为tf.train.Features(注意这里有个s) return example_proto # .SerializeToString() ⚠️⚠️⚠️ 下面涉及到保存数据,一定要先序列化成二进制字符串才能存为tfrecord!
用一条数据测试一下:
example_observation = [] serialized_example = serialize_example(False, 4, b'goat', 0.9876) serialized_example
执行结果:
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 4 } } } feature { key: "feature2" value { bytes_list { value: "goat" } } } feature { key: "feature3" value { float_list { value: 0.9876000285148621 } } } }
可以将这个 serialized_example 使用 .SerializeToString 方法序列化为二进制字符串:
serialized_example.SerializeToString()
结果为:
b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'
2. tf.Record
对于这种 [False, 4, bytes('goat'), 0.9876] 特殊的数据集,比如这4个feature,可以通过tf.Record与tf.Example交互的方式来构建处理数据。
要将数据放入数据集中,最简单的方式是使用 from_tensor_slices 方法。例如可以将其中的一个feature放进来:
tf.data.Dataset.from_tensor_slices(feature1)
类型: <TensorSliceDataset shapes: (), types: tf.int64>
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
类型: <TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
通过take方法查看每条数据:
for f0,f1,f2,f3 in features_dataset.take(1): print(f0) print(f1) print(f2) print(f3)
tf.Tensor(False, shape=(), dtype=bool) tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(b'dog', shape=(), dtype=string) tf.Tensor(-0.07658295354196158, shape=(), dtype=float64)
使用 tf.data.Dataset.map 方法可将函数应用于 Dataset 的每个元素。对于上面的features_dataset,可以通过map函数对其中的每一条数据进行map,从而转换为tf.Example格式。首先定义这个map函数:
def tf_serialize_example(f0,f1,f2,f3): tf_string = tf.py_function( serialize_example, # 调用了上面定义的转换函数 (f0,f1,f2,f3), # pass these args to the above function. tf.string) # the return type is `tf.string`. return tf.reshape(tf_string, ()) # The result is a scalar
进行转换:
serialized_features_dataset = features_dataset.map(tf_serialize_example) # 此函数将会对features_dataset中的每条数据进行转换
serialized_features_dataset # 这个dataset类可以存为tfrecord
输出为: <MapDataset shapes: (), types: tf.string> 。官方也给出下面通过生成器的方式保存为tfrecord:
def generator(): for features in features_dataset: yield serialize_example(*features) # 转为tf.Example
serialized_features_dataset = tf.data.Dataset.from_generator( generator, output_types=tf.string, output_shapes=()) serialized_features_dataset
输出为: <FlatMapDataset shapes: (), types: tf.string>
filename = 'test.tfrecord' # 保存为tfrecord writer = tf.data.experimental.TFRecordWriter(filename) writer.write(serialized_features_dataset)
3. 读取tfrecord
raw_dataset = tf.data.TFRecordDataset('test.tfrecord')
raw_dataset
读进来后,数据集包含序列化的 tf.train.Example 消息。迭代时,它会将其作为标量字符串张量返回。使用 .take 方法显示前 3条记录
for raw_record in raw_dataset.take(3): print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xbd\x15\x9b?\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xbb\x0b\x04?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xa1\x9c\x98\xbf'>
for raw_record in raw_dataset.take(1): example = tf.train.Example() example.ParseFromString(raw_record.numpy()) # 使用parsefromstring函数来解析二进制record字符串 print(example)
可以使用以下函数对这些张量进行解析。请注意,这里的 feature_description 是必需的,因为数据集使用计算图执行,并且需要以下描述来构建它们的形状和类型签名:
# Create a description of the features. feature_description = { 'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0), 'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0), 'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''), 'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0), } def _parse_function(example_proto): # 用上面的特征字典来解析tf.Example # Parse the input `tf.Example` proto using the dictionary above. return tf.io.parse_single_example(example_proto, feature_description)
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset shapes: {feature0: (), feature1: (), feature2: (), feature3: ()}, types: {feature0: tf.int64, feature1: tf.int64, feature2: tf.string, feature3: tf.float32}>
for parsed_record in parsed_dataset.take(3): print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.2116009>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.515804>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.1922799>}
从上面的解析结果可以看到feature0到feature3的numpy值。也可以单独取到某组特征:
for parsed_record in parsed_dataset.take(10): print(repr(parsed_record['feature0'].numpy())) # 比如只取到feature0
可以将解析后的数据组成训练数据:
parsed_dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=xxx)
可以在解析example时改变特征值:
feature_description = { 'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0), 'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0), 'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''), 'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0), } def _parse_function(example_proto): # Parse the input `tf.Example` proto using the dictionary above. example = tf.io.parse_single_example(example_proto, feature_description) # example['feature0']+=1112 # 给某个特征+1112 # del example['feature0'] # 删掉某个特征 return example parsed_dataset = raw_dataset.map(_parse_function) parsed_dataset
官方文档中还提供了将图像保存为tfrecord并读取回来进行显示的例子:传送门。

浙公网安备 33010602011771号