tf.lookup.StaticHashTable 用法
tf.lookup.StaticHashTable 本质上是 TensorFlow 内置的字典（初始化后不可修改），在 YOLOv3 的 TensorFlow 实现中多次用到，例如下面的 load_tfrecord_dataset 就用它把类别名映射为类别编号。
def load_tfrecord_dataset(file_pattern, class_file, size=416):
    """Build a ``tf.data`` pipeline of parsed examples from TFRecord files.

    Args:
        file_pattern: Glob pattern handed to ``tf.data.Dataset.list_files``.
        class_file: Path to a text file with one class name per line; each
            name is mapped to the (0-based) number of the line it is on.
        size: Forwarded unchanged to ``parse_tfrecord`` (presumably the
            target image side length -- confirm against ``parse_tfrecord``).

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``parse_tfrecord``
        produces for each serialized record.
    """
    # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
    # -1 selects "line number" as the value stored for each key.
    LINE_NUMBER = -1
    # Key = column 0 of each line (splitting on "\n" leaves the whole line),
    # value = its line number; names missing from the file look up as -1.
    class_table = tf.lookup.StaticHashTable(
        tf.lookup.TextFileInitializer(
            class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"),
        -1)
    files = tf.data.Dataset.list_files(file_pattern)
    # Read every matching TFRecord file and chain their records together.
    dataset = files.flat_map(tf.data.TFRecordDataset)
    return dataset.map(lambda x: parse_tfrecord(x, class_table, size))
###### 在当前目录下新建文件 voc2012.names，每行一个类别名（从 0 开始的行号即该类别的编号）：
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import tensorflow as tf
# Build the table from the voc2012.names file created above.
# LINE_NUMBER = -1 makes each key's value its 0-based line number.
class_file = "voc2012.names"
LINE_NUMBER = -1
class_table = tf.lookup.StaticHashTable(
    tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"),
    -1)  # default value returned for names not found in the file

# Look up several names at once: 'cat' is on line 7, 'person' on line 14.
class_table.lookup(tf.constant(['cat', 'person']))
# Output:
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([ 7, 14])>

# export() returns a (keys, values) pair of tensors; as the output shows,
# the entries come back in the table's internal order, not file order.
class_table.export()
# Output:
# (<tf.Tensor: shape=(20,), dtype=string, numpy=
# array([b'cat', b'chair', b'dog', b'person', b'bird', b'motorbike',
# b'bottle', b'car', b'bus', b'sheep', b'boat', b'train',
# b'aeroplane', b'pottedplant', b'sofa', b'tvmonitor', b'cow',
# b'diningtable', b'horse', b'bicycle'], dtype=object)>,
# <tf.Tensor: shape=(20,), dtype=int64, numpy=
# array([ 7, 8, 11, 14, 2, 13, 4, 6, 5, 16, 3, 18, 0, 15, 17, 19, 9,
# 10, 12, 1])>)

浙公网安备 33010602011771号