import numpy as np
def parse_tfrec(filename):
    """Return the feature map of the last ``tf.train.Example`` in a TFRecord file.

    Every record in ``filename`` is read, but only the final one is decoded
    into a ``tf.train.Example``; its ``features.feature`` map (feature name ->
    ``tf.train.Feature``) is returned.

    Args:
        filename: path to the TFRecord file.

    Returns:
        The ``features.feature`` map of the last record, or ``None`` when the
        file contains no records.
    """
    # NOTE(review): ``tf`` is not imported in the visible part of this file;
    # presumably ``import tensorflow as tf`` is required at the top — confirm.
    last_record = None
    for raw_record in tf.data.TFRecordDataset(filename):
        # Only remember the raw serialized record; decoding each one just to
        # overwrite the result on the next iteration was wasted work.
        last_record = raw_record
    if last_record is None:
        return None
    example = tf.train.Example()
    example.ParseFromString(last_record.numpy())
    return example.features.feature
def show_info(feature):
    """Print a one-line summary of every entry in a TFRecord feature map.

    For each key:
      * ``bytes_list`` entries print their raw values;
      * ``float_list`` and ``int64_list`` entries print the value array when it
        holds exactly one element, and only its shape otherwise.
    Entries with none of these kinds set print nothing.

    Args:
        feature: mapping of feature name -> ``tf.train.Feature``-like object
            (exposes ``HasField`` and ``bytes_list`` / ``float_list`` /
            ``int64_list`` with a ``value`` sequence), such as the map
            returned by ``parse_tfrec``.
    """
    for k in feature:
        entry = feature[k]
        # The three kinds are mutually exclusive, so an elif chain suffices.
        if entry.HasField('bytes_list'):
            print(k, "bytes_list", entry.bytes_list.value)
        elif entry.HasField('float_list'):
            _print_numeric(k, "float_list", entry.float_list.value)
        elif entry.HasField('int64_list'):
            # Previously int64 features were silently skipped.
            _print_numeric(k, "int64_list", entry.int64_list.value)


def _print_numeric(k, kind, values):
    """Print a single-value numeric feature in full, or the shape of a longer one."""
    arr = np.array(values)
    if len(arr) == 1:
        print(k, kind, arr)
    else:
        print(k, kind, arr.shape)
# Script entry point: summarize the features of the last record in a TFRecord
# file. The path may be given as the first command-line argument; it defaults
# to the sample file used originally. Guarded so importing this module does
# not trigger file I/O.
if __name__ == "__main__":
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "./samples_split/AO7_2016_2_8.tfrec"
    feature = parse_tfrec(path)
    show_info(feature)