yolov3-tf2 数据格式压缩

像 VOC 这样的图像数据集占用空间比较大。以往一般把图像解码成矩阵后放在内存中进行模型训练,需要占用大量内存。
TensorFlow 提供了 tf.io.TFRecordWriter 接口,可以把图像的原始 JPEG 编码字节连同标注信息一起序列化写入 TFRecord 文件,而不必保存解码后的矩阵,从而节省空间。示例代码如下:
import time
import os
import hashlib

from absl import app, flags, logging
from absl.flags import FLAGS
import tensorflow as tf
import lxml.etree
import tqdm

# flags.DEFINE_string('data_dir', './data/voc2012_raw/VOCdevkit/VOC2012/',
#                     'path to raw PASCAL VOC dataset')
# flags.DEFINE_enum('split', 'train', [
#                   'train', 'val'], 'specify train or val spit')
# flags.DEFINE_string('output_file', './data/voc2012_train.tfrecord', 'outpot dataset')
# flags.DEFINE_string('classes', './data/voc2012.names', 'classes file')


def build_example(annotation, class_map,
                  data_dir='./data/voc2012_raw/VOCdevkit/VOC2012'):
    """Build a ``tf.train.Example`` for one annotated image.

    Args:
        annotation: Nested dict as produced by ``parse_xml(...)['annotation']``.
            Must contain ``filename`` and ``size`` (``width``/``height``);
            ``object`` is optional and, when present, is a list of dicts
            with ``name``, ``difficult``, ``truncated``, ``pose`` and
            ``bndbox`` (``xmin``/``ymin``/``xmax``/``ymax``).
        class_map: Mapping from class name to integer label.
        data_dir: Dataset root containing the ``JPEGImages`` directory.
            Defaults to the path the original script hard-coded.

    Returns:
        A ``tf.train.Example`` holding the raw image bytes, image metadata
        and per-object lists. Box coordinates are divided by the image
        width/height taken from the annotation.

    Raises:
        OSError: If the image file cannot be read.
        KeyError: If a required annotation field or a class name missing
            from ``class_map`` is encountered.
    """
    def _int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def _bytes_feature(values):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    def _float_feature(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    img_path = os.path.join(data_dir, 'JPEGImages', annotation['filename'])
    # Use a context manager so the handle is released right away; this
    # function is called once per image and would otherwise leak handles.
    with open(img_path, 'rb') as img_file:
        img_raw = img_file.read()
    key = hashlib.sha256(img_raw).hexdigest()

    width = int(annotation['size']['width'])
    height = int(annotation['size']['height'])

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    views = []
    difficult_obj = []
    # Annotations without any <object> element have no 'object' key.
    for obj in annotation.get('object', []):
        # Any non-zero 'difficult' value is stored as 1.
        difficult_obj.append(int(bool(int(obj['difficult']))))

        bndbox = obj['bndbox']
        xmin.append(float(bndbox['xmin']) / width)
        ymin.append(float(bndbox['ymin']) / height)
        xmax.append(float(bndbox['xmax']) / width)
        ymax.append(float(bndbox['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(class_map[obj['name']])
        truncated.append(int(obj['truncated']))
        views.append(obj['pose'].encode('utf8'))

    filename = annotation['filename'].encode('utf8')
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature([height]),
        'image/width': _int64_feature([width]),
        'image/filename': _bytes_feature([filename]),
        'image/source_id': _bytes_feature([filename]),
        'image/key/sha256': _bytes_feature([key.encode('utf8')]),
        'image/encoded': _bytes_feature([img_raw]),
        'image/format': _bytes_feature(['jpeg'.encode('utf8')]),
        'image/object/bbox/xmin': _float_feature(xmin),
        'image/object/bbox/xmax': _float_feature(xmax),
        'image/object/bbox/ymin': _float_feature(ymin),
        'image/object/bbox/ymax': _float_feature(ymax),
        'image/object/class/text': _bytes_feature(classes_text),
        'image/object/class/label': _int64_feature(classes),
        'image/object/difficult': _int64_feature(difficult_obj),
        'image/object/truncated': _int64_feature(truncated),
        'image/object/view': _bytes_feature(views),
    }))
    return example


def parse_xml(xml):
    """Recursively convert an XML element into nested dicts.

    A childless element becomes ``{tag: text}``. Otherwise the result is
    ``{tag: {...}}`` with one entry per child tag; a later child with the
    same tag replaces an earlier one, except for ``object`` children,
    which are collected into a list in document order.
    """
    if len(xml) == 0:
        return {xml.tag: xml.text}

    children = {}
    for node in xml:
        value = parse_xml(node)[node.tag]
        if node.tag == 'object':
            # Multiple <object> siblings are expected, so keep them all.
            children.setdefault(node.tag, []).append(value)
        else:
            children[node.tag] = value
    return {xml.tag: children}


def main():
    """Convert the VOC2012 ``train`` split into a single TFRecord file.

    Reads the class names from ``./data/voc2012.names``, the image ids from
    ``ImageSets/Main/train.txt`` under the dataset root, and writes one
    serialized ``tf.train.Example`` per image to
    ``./data/voc2012_train.tfrecord``.

    Raises:
        OSError: If any input file cannot be read or the output file cannot
            be created.
    """
    voc_dir = './data/voc2012_raw/VOCdevkit/VOC2012'

    # Map each class name to its line index in the names file.
    with open('./data/voc2012.names') as names_file:
        class_map = {name: idx for idx, name in enumerate(
            names_file.read().splitlines())}

    split_path = os.path.join(
        voc_dir, 'ImageSets', 'Main', '%s.txt' % 'train')
    with open(split_path) as split_file:
        image_list = split_file.read().splitlines()
    logging.info("Image list loaded: %d", len(image_list))

    # Images are stored in the record as their original encoded bytes plus
    # annotations, rather than as decoded matrices, which keeps the dataset
    # small. The context manager closes the writer even if a record fails.
    with tf.io.TFRecordWriter('./data/voc2012_train.tfrecord') as writer:
        for name in tqdm.tqdm(image_list):
            annotation_path = os.path.join(
                voc_dir, 'Annotations', name + '.xml')
            # Read bytes so the XML parser handles the encoding itself;
            # lxml rejects str input that carries an encoding declaration.
            with open(annotation_path, 'rb') as xml_file:
                annotation_xml = lxml.etree.fromstring(xml_file.read())
            annotation = parse_xml(annotation_xml)['annotation']
            tf_example = build_example(annotation, class_map)
            writer.write(tf_example.SerializeToString())
    logging.info("Done")


# Run the conversion only when executed as a script, so that importing this
# module does not start reading and writing the dataset as a side effect.
if __name__ == '__main__':
    main()

posted @ 2022-08-19 22:49  luoganttcc  阅读(12)  评论(0)    收藏  举报