C++ 解析并读取 TFRecord 文件 using Protocal Buffer API
背景介绍
最近实习中遇到了C++解析TFRecord的需求,搜索一圈发现虽然tensorflow C++ API中提供了TFRecordReader的接口,但是编译C++版本的Tensorflow并不容易&很不清真,把他当做自己的项目的依赖就更离谱了。内网外网找了很久都发现没有相关的教程,于是调研了一圈,写了个自定义的解析脚本,只需要安装了解protobuf即可使用。读懂本文以及使用对应代码需要对protobuf有一定了解。
TFRecord文件格式
TFRecord的官方文档说明了TFRecord的内容由若干tf.train.Example组成,每条tf.train.Example其实就是一个protobuf的message,而这个message的定义文件就是tensorflow的代码库中的example.proto文件(目前的文件链接)。
值得注意的是,为了保证数据的正确性,TFRecord给每一个tf.train.Example添加了一些header和footer,用来描述二进制数据的Bytes长度以及进行crc校验。这些额外的信息对应的定义在tensorflow的代码库中的record_writer.h文件里(目前的文件链接),最重要的就是下面几行
class RecordWriter {
public:
// Format of a single record:
// uint64 length
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
static constexpr size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static constexpr size_t kFooterSize = sizeof(uint32);
可以看到,对应tf.train.Example序列化的有效数据的就是中间的data字段,data的长度是length字段,此外这俩字段都还有一个crc校验字段。
解析思路和代码描述
这样解析tfrecord的思路就很清晰了:
- 拿到一个tfrecord文件,以二进制模式读
- 前8个字节读到一个uint64变量中,获取length信息
- 跳过4个字节的length的crc校验码(只要胆子大)
- 读取length个字节的data,交给protobuf接口来解析
- 跳过4个字节的data的crc校验码
- 重复2-5步,直到解析完整个文件
前面已经说了,tf.train.Example的定义在example.proto文件中,这个文件还引用了同目录下的feature.proto,直接用protoc编译这俩文件,可以获得解析tf.train.Example二进制序列的API,然后在你自己的读写脚本中使用API即可【嵌套关系是:class Example的内容是 class Features,class Features的内容是若干个class Feature,class Feature的内容是以下三个中的一个:class Int64List class BytesList class FloatList。多看几遍example.pb.h和feature.pb.h琢磨琢磨即可。】
代码放在github了: https://github.com/initzhang/TFRecord-Parser

浙公网安备 33010602011771号