C++ 解析并读取 TFRecord 文件 using Protocal Buffer API

背景介绍

最近实习中遇到了C++解析TFRecord的需求,搜索一圈发现虽然tensorflow C++ API中提供了TFRecordReader的接口,但是编译C++版本的Tensorflow并不容易&很不清真,把他当做自己的项目的依赖就更离谱了。内网外网找了很久都发现没有相关的教程,于是调研了一圈,写了个自定义的解析脚本,只需要安装了解protobuf即可使用。读懂本文以及使用对应代码需要对protobuf有一定了解。

TFRecord文件格式

TFRecord的官方文档说明了TFRecord的内容由若干tf.train.Example组成,每条tf.train.Example其实就是一个protobuf的message,而这个message的定义文件就是tensorflow的代码库中的example.proto文件(目前的文件链接)。

值得注意的是,为了保证数据的正确性,TFRecord给每一个tf.train.Example添加了一些header和footer,用来描述二进制数据的Bytes长度以及进行crc校验。这些额外的信息对应的定义在tensorflow的代码库中的record_writer.h文件里(目前的文件链接),最重要的就是下面几行

class RecordWriter {
 public:
  // Format of a single record:
  //  uint64    length
  //  uint32    masked crc of length
  //  byte      data[length]
  //  uint32    masked crc of data
  static constexpr size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
  static constexpr size_t kFooterSize = sizeof(uint32);

可以看到,对应tf.train.Example序列化的有效数据的就是中间的data字段,data的长度是length字段,此外这俩字段都还有一个crc校验字段。

解析思路和代码描述

这样解析tfrecord的思路就很清晰了:

  1. 拿到一个tfrecord文件,以二进制模式读
  2. 前8个字节读到一个uint64变量中,获取length信息
  3. 跳过4个字节的length的crc校验码(只要胆子大)
  4. 读取length个字节的data,交给protobuf接口来解析
  5. 跳过4个字节的data的crc校验码
  6. 重复2-5步,直到解析完整个文件

前面已经说了,tf.train.Example的定义在example.proto文件中,这个文件还引用了同目录下的feature.proto,直接用protoc编译这俩文件,可以获得解析tf.train.Example二进制序列的API,然后在你自己的读写脚本中使用API即可【嵌套关系是:class Example的内容是 class Features,class Features的内容是若干个class Feature,class Feature的内容是以下三个中的一个:class Int64List class BytesList class FloatList。多看几遍example.pb.h和feature.pb.h琢磨琢磨即可。】

代码放在github了: https://github.com/initzhang/TFRecord-Parser

posted @ 2021-01-15 15:57  initzhang  阅读(413)  评论(0)    收藏  举报