好文章 about tfrecord
深入浅出的TensorFlow数据格式化存储工具TFRecord用法教程
TFRecord是TensorFlow官方推荐使用的数据格式化存储工具,它不仅规范了数据的读写方式,还大大地提高了IO效率。
1.使用TFRecord的理由
TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。
而且,使用TFRecord可以直接作为Cloud ML Engine的输入数据。
一般来说,我们使用TensorFlow进行数据读取的方式有以下4种:
预先把所有数据加载进内存
在每轮训练中使用原生Python代码读取一部分数据,然后使用feed_dict输入到计算图
利用Threading和Queues从TFRecord中分批次读取数据
使用Dataset API
(1)方案对于数据量不大的场景来说是足够简单而高效的,但是随着数据量的增长,势必会对有限的内存空间带来极大的压力,还有长时间的数据预加载,甚至导致我们十分熟悉的OutOfMemoryError;
(2)方案可以一定程度上缓解了方案(1)的内存压力问题,但是由于在单线程环境下我们的IO操作一般都是同步阻塞的,势必会在一定程度上导致学习时间的增加,尤其是相同的数据需要重复多次读取的情况下;
而方案(3)和方案(4)都利用了我们的TFRecord,由于使用了多线程使得IO操作不再阻塞我们的模型训练,同时为了实现线程间的数据传输引入了Queues。
2.准备数据
下面,我们以Fashion MNIST数据集为例,介绍生成TFRecrd的方法。
所谓的Fashion MNIST数据集,其实就是大小为28*28的共10个类别的服装图像:
下面我们把数据集下载并保存到data/fashion目录下:
$ mkdir -p data/fashin $ cd data/fashion $ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz $ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz $ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz $ wget http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz $ cd ../.. --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
然后,我们在TensorFlow使用和MNIST数据集相同的代码进行数据读取:
from tensorflow.examples.tutorials.mnist import input_data fashion_mnist = input_data.read_data_sets('data/fashion')
3.Example记录和SequenceExample记录
使用TFRecord时,一般以tf.train.Example和tf.train.SequenceExample作为基本单位来进行数据读取。
tf.train.Example一般用于数值、图像等有固定大小的数据,同时使用tf.train.Feature指定每个记录各特征的名称和数据类型,用法如下:
tf.train.Example(features=tf.train.Features(feature={ 'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])), 'width' : tf.train.Feature(int64_list=tf.train.Int64List(value=[width])), 'depth' : tf.train.Feature(int64_list=tf.train.Int64List(value=[depth])), 'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])) })) --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
tf.train.SequenceExample一般用于文本、时间序列等没有固定长度大小的数据,用法如下:
example = tf.train.SequenceExample() # 通过context来指定数据量的大小 example.context.feature["length"].int64_list.value.append(len(data)) # 通过feature_lists来加载数据 words_list = example.feature_lists.feature_list["words"] for word in words: words_list.feature.add().int64_list.value.append(word_id(word)) --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
4.生成TFRecord
接下来,让我们把原始的Fashion MNIST数据集转化为TFRecord并保存下来:
import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data def make_example(image, label): return tf.train.Example(features=tf.train.Features(feature={ 'image' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), 'label' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])) })) def write_tfrecord(images, labels, filename): writer = tf.python_io.TFRecordWriter(filename) for image, label in zip(images, labels): labels = labels.astype(np.float32) ex = make_example(image.tobytes(), label.tobytes()) writer.write(ex.SerializeToString()) writer.close() def main(): fashion_mnist = input_data.read_data_sets('data/fashion', one_hot=True) train_images = fashion_mnist.train.images train_labels = fashion_mnist.train.labels test_images = fashion_mnist.test.images test_labels = fashion_mnist.test.labels write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecord') write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecord') if __name__ == '__main__': main() --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
执行了上面的代码后,会在当前工作目录下生成两个TFRecord数据文件——fashion_mnist_train.tfrecord和fashion_mnist_test.tfrecord。
5.确认TFRecord的内容
如果我们想确认下刚才生成的TFRecord是否合乎我们的预期,tf.train.Example.FromString应该是不二之选了。
In [1]: import tensorflow as tf In [2]: example = next(tf.python_io.tf_record_iterator("fashion_mnist_train.tfrecord")) In [3]: tf.train.Example.FromString(example) Out[3]: features { feature { feature { key: "image" value { bytes_list { value: "\000...\000" } } } feature { key: "label" value { bytes_list { value: "\000...\000" } } } } --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
由此可知,features包含了image、label、height、width等特征。
6.读取TFRecord
为了完成这项任务,推荐使用tf.parse_single_example:
def read_tfrecord(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.string) }) image = tf.decode_raw(features['image'], tf.float32) label = tf.decode_raw(features['label'], tf.float64) image = tf.reshape(image, [28, 28, 1]) label = tf.reshape(label, [10]) image, label = tf.train.batch([image, label], batch_size=16, capacity=500) return image, label --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!
7.整合
下面让我们把TFRecord使用到真实的模型训练场景中,虽然这次的Fashion MNIST数据量并不算大,完全可以一次性全部加载到内存中,但我们的TFRecord一样有用武之地,就是实现异步IO。
import numpy as np import tensorflow as tf import tfrecord_io from tensorflow.examples.tutorials.mnist import input_data from tensorflow.contrib import slim def model(image, label): net = slim.conv2d(image, 48, [5,5], scope='conv1') net = slim.max_pool2d(net, [2,2], scope='pool1') net = slim.conv2d(net, 96, [5,5], scope='conv2') net = slim.max_pool2d(net, [2,2], scope='pool2') net = slim.flatten(net, scope='flatten') net = slim.fully_connected(net, 512, scope='fully_connected1') logits = slim.fully_connected(net, 10, activation_fn=None, scope='fully_connected2') prob = slim.softmax(logits) loss = slim.losses.softmax_cross_entropy(logits, label) train_op = slim.optimize_loss(loss, slim.get_global_step(), learning_rate=0.001, optimizer='Adam') return train_op def main(): train_images, train_labels = tfrecord_io.read_tfrecord('fashion_mnist_train.tfrecord') train_op = model(train_images, train_labels) step = 0 with tf.Session() as sess: init_op = tf.group( tf.local_variables_initializer(), tf.global_variables_initializer()) sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) while step < 3000: sess.run([train_op]) if step % 100 == 0: print('step: {}'.format(step)) step += 1 coord.request_stop() coord.join(threads) if __name__ == '__main__': main() --------------------- 作者:烧煤的快感 来源:CSDN 原文:https://blog.csdn.net/gg_18826075157/article/details/78449104 版权声明:本文为博主原创文章,转载请附上博文链接!