# 对于 VOC 这样的图像数据集,占用空间比较大。之前一般把图像解码成矩阵放在内存中进行模型训练,需要大量内存。
# TensorFlow 提供 tf.io.TFRecordWriter 数据 API,可以把每张图像的原始 JPEG 字节及其标注序列化写入 TFRecord 文件,比存储解码后的矩阵更节省空间。
import time
import os
import hashlib
from absl import app, flags, logging
from absl.flags import FLAGS
import tensorflow as tf
import lxml.etree
import tqdm
# flags.DEFINE_string('data_dir', './data/voc2012_raw/VOCdevkit/VOC2012/',
# 'path to raw PASCAL VOC dataset')
# flags.DEFINE_enum('split', 'train', [
# 'train', 'val'], 'specify train or val split')
# flags.DEFINE_string('output_file', './data/voc2012_train.tfrecord', 'output dataset')
# flags.DEFINE_string('classes', './data/voc2012.names', 'classes file')
def build_example(annotation, class_map,
                  data_dir='./data/voc2012_raw/VOCdevkit/VOC2012'):
    """Build a ``tf.train.Example`` for a single PASCAL VOC annotation.

    Args:
        annotation: dict produced by ``parse_xml(...)['annotation']``. It must
            contain ``filename`` and ``size`` (with ``width`` / ``height``);
            the ``object`` list is optional (images without objects yield
            empty per-object features).
        class_map: mapping from class name to integer label.
        data_dir: root of the VOC tree; the image is read from
            ``<data_dir>/JPEGImages/<filename>``. Defaults to the path the
            script previously hard-coded.

    Returns:
        A ``tf.train.Example`` holding the raw (still JPEG-encoded) image
        bytes, their SHA-256 hex digest, the image size, and per-object
        boxes (pixel coordinates divided by image width/height), labels,
        and difficult/truncated/pose values.

    Raises:
        KeyError: if an object's class name is not in ``class_map``.
        OSError: if the image file cannot be read.
    """
    img_path = os.path.join(data_dir, 'JPEGImages', annotation['filename'])
    # Use a context manager so the file handle is closed deterministically.
    with open(img_path, 'rb') as img_file:
        img_raw = img_file.read()
    key = hashlib.sha256(img_raw).hexdigest()

    width = int(annotation['size']['width'])
    height = int(annotation['size']['height'])

    xmin, ymin, xmax, ymax = [], [], [], []
    classes, classes_text = [], []
    truncated, views, difficult_obj = [], [], []
    for obj in annotation.get('object', []):
        bndbox = obj['bndbox']
        # Some VOC annotations omit (or leave empty) these optional tags;
        # fall back to neutral defaults instead of raising KeyError/TypeError.
        difficult = bool(int(obj.get('difficult') or 0))
        difficult_obj.append(int(difficult))
        xmin.append(float(bndbox['xmin']) / width)
        ymin.append(float(bndbox['ymin']) / height)
        xmax.append(float(bndbox['xmax']) / width)
        ymax.append(float(bndbox['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(class_map[obj['name']])
        truncated.append(int(obj.get('truncated') or 0))
        # 'Unspecified' is the placeholder VOC itself uses for unknown pose.
        views.append((obj.get('pose') or 'Unspecified').encode('utf8'))

    def _int64(values):
        # Wrap a list of ints as an int64 feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def _float(values):
        # Wrap a list of floats as a float feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _bytes(values):
        # Wrap a list of byte strings as a bytes feature.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    filename = annotation['filename'].encode('utf8')
    feature = {
        'image/height': _int64([height]),
        'image/width': _int64([width]),
        'image/filename': _bytes([filename]),
        'image/source_id': _bytes([filename]),
        'image/key/sha256': _bytes([key.encode('utf8')]),
        'image/encoded': _bytes([img_raw]),
        'image/format': _bytes(['jpeg'.encode('utf8')]),
        'image/object/bbox/xmin': _float(xmin),
        'image/object/bbox/xmax': _float(xmax),
        'image/object/bbox/ymin': _float(ymin),
        'image/object/bbox/ymax': _float(ymax),
        'image/object/class/text': _bytes(classes_text),
        'image/object/class/label': _int64(classes),
        'image/object/difficult': _int64(difficult_obj),
        'image/object/truncated': _int64(truncated),
        'image/object/view': _bytes(views),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def parse_xml(xml):
    """Recursively convert an XML element into nested dicts.

    A leaf element (one with no children) becomes ``{tag: text}``. Any other
    element becomes ``{tag: {child_tag: child_value, ...}}``. Children tagged
    ``object`` are gathered into a list, because one annotation can hold
    several objects. For any other repeated tag, the last occurrence wins.
    """
    children = list(xml)
    if not children:
        return {xml.tag: xml.text}

    parsed = {}
    for node in children:
        value = parse_xml(node)[node.tag]
        if node.tag == 'object':
            parsed.setdefault(node.tag, []).append(value)
        else:
            parsed[node.tag] = value
    return {xml.tag: parsed}
def main():
    """Convert the VOC2012 ``train`` split into one TFRecord file.

    Reads the class names from ``./data/voc2012.names``; each name's label
    is its line index. Reads the image ids from
    ``ImageSets/Main/train.txt`` and parses each id's XML annotation. Writes
    one serialized ``tf.train.Example`` per image to
    ``./data/voc2012_train.tfrecord``.
    """
    voc_root = './data/voc2012_raw/VOCdevkit/VOC2012'

    # Map each class name to an integer label (its line index in the file).
    with open('./data/voc2012.names') as names_file:
        class_map = {name: idx for idx, name in
                     enumerate(names_file.read().splitlines())}

    list_path = os.path.join(voc_root, 'ImageSets', 'Main', '%s.txt' % 'train')
    with open(list_path) as list_file:
        # Skip blank lines so they don't become a bogus '.xml' lookup.
        image_list = [line.strip() for line in list_file.read().splitlines()
                      if line.strip()]
    logging.info("Image list loaded: %d", len(image_list))

    # Store each image as its original JPEG-encoded bytes in a TFRecord,
    # rather than as a decoded pixel matrix, because the encoded form takes
    # far less space. The context manager closes the writer even if
    # processing an image raises.
    with tf.io.TFRecordWriter('./data/voc2012_train.tfrecord') as writer:
        for name in tqdm.tqdm(image_list):
            xml_path = os.path.join(voc_root, 'Annotations', name + '.xml')
            # Read bytes: lxml rejects str input that carries an XML
            # encoding declaration, and this also closes the file.
            with open(xml_path, 'rb') as xml_file:
                annotation_xml = lxml.etree.fromstring(xml_file.read())
            annotation = parse_xml(annotation_xml)['annotation']
            tf_example = build_example(annotation, class_map)
            writer.write(tf_example.SerializeToString())
    logging.info("Done")
# Run the conversion only when this file is executed as a script, so that
# importing it (e.g. to reuse build_example / parse_xml) has no side effects.
# NOTE: absl's app.run(main) would call main(argv), which the zero-argument
# main() does not accept, so it is called directly here.
if __name__ == '__main__':
    main()