# Python: read a single record from a TFRecord file (python读取tfrecord单个数据)

import numpy as np
def parse_tfrec(filename):
    """Return the feature map of the last ``tf.train.Example`` in a TFRecord file.

    Every record in *filename* has to be read, because a TFRecord file can
    only be scanned sequentially. Only the final record is decoded, so the
    earlier records cost no protobuf parsing.

    Args:
        filename: path to a TFRecord file. Its records are assumed to be
            serialized ``tf.train.Example`` protos (verify for your data).

    Returns:
        The ``example.features.feature`` map (feature name ->
        ``tf.train.Feature``) of the last record, or ``None`` if the file
        contains no records.
    """
    # Imported locally: the module only imports numpy at top level, and
    # without this import ``tf`` is an undefined name.
    import tensorflow as tf

    # Keep only the raw bytes of the last record seen; parse it once below.
    last_raw = None
    for raw_record in tf.data.TFRecordDataset(filename):
        last_raw = raw_record
    if last_raw is None:
        return None

    example = tf.train.Example()
    example.ParseFromString(last_raw.numpy())
    return example.features.feature

def show_info(feature):
    """Print a one-line summary of every feature in a ``tf.train.Example``.

    Each line has the form ``<name> <kind> <summary>``, where *kind* is the
    populated field of the ``Feature`` ``kind`` oneof: ``bytes_list``,
    ``float_list`` or ``int64_list``.

    - Bytes values are printed in full.
    - A numeric list holding exactly one value is printed as a NumPy array.
    - Any other numeric list is printed as its shape only.

    Args:
        feature: the feature map returned by ``parse_tfrec``
            (``example.features.feature``), or ``None``.
    """
    if not feature:
        # parse_tfrec returns None for a file with no records; report that
        # instead of crashing on attribute access.
        print("no features to show")
        return

    for name, value in feature.items():
        kind = value.WhichOneof("kind")
        if kind is None:
            print(name, "(empty feature)")
        elif kind == "bytes_list":
            print(name, kind, value.bytes_list.value)
        else:
            # float_list or int64_list: show scalars directly, otherwise only
            # the shape so large tensors do not flood the output.
            values = np.array(getattr(value, kind).value)
            print(name, kind, values if len(values) == 1 else values.shape)

if __name__ == "__main__":
    import sys

    # Use the path given on the command line, falling back to the
    # script's original hard-coded sample file.
    path = sys.argv[1] if len(sys.argv) > 1 else "./samples_split/AO7_2016_2_8.tfrec"
    feature = parse_tfrec(path)
    show_info(feature)
# posted @ 2025-02-15 21:23  提高效率!  阅读(20)  评论(0)    收藏  举报