常用函数
获取 SavedModel 模型签名的输入节点（张量）信息
def collect_signature_tensors(signatures, field):
    """Return tensor names and static shapes for every SavedModel signature.

    Args:
        signatures: Mapping of signature name -> ``SignatureDef``, e.g.
            ``meta_graph_def.signature_def``.
        field: Which side of each signature to read: ``"inputs"`` or
            ``"outputs"``.

    Returns:
        ``{signature_name: {key: {"name": tensor_name, "shape": shape}}}``.
        Results are nested per signature because the same key may appear in
        several signatures; a flat ``{key: name}`` dict would silently keep
        only the last one. ``shape`` is a list of ints (``-1`` marks a
        dimension of unknown size) or ``None`` when the rank is unknown.
    """
    result = {}
    for sig_name, sig_def in signatures.items():
        tensors = {}
        for key, tensor_info in getattr(sig_def, field).items():
            shape_proto = tensor_info.tensor_shape
            if shape_proto.unknown_rank:
                # No dim entries exist in this case; an empty list would
                # wrongly read as a scalar, so report the shape as unknown.
                shape = None
            else:
                shape = [int(dim.size) for dim in shape_proto.dim]
            tensors[key] = {"name": tensor_info.name, "shape": shape}
        result[sig_name] = tensors
    return result


if __name__ == "__main__":
    # Imported here so collect_signature_tensors can be reused and tested
    # without TensorFlow installed.
    import tensorflow as tf
    # NOTE: saved_model_utils lives under tensorflow.python, which is internal
    # (not part of the public tf.* API) and may change between releases.
    from tensorflow.python.tools import saved_model_utils

    model_dir = 'model_dir'  # SavedModel directory to inspect; edit as needed
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        model_dir, tf.saved_model.SERVING
    )
    signatures = meta_graph_def.signature_def
    input_tensor_info = collect_signature_tensors(signatures, "inputs")
    print(input_tensor_info)
获取 SavedModel 模型签名的输出节点（张量）信息
def collect_signature_tensors(signatures, field):
    """Return tensor names and static shapes for every SavedModel signature.

    Args:
        signatures: Mapping of signature name -> ``SignatureDef``, e.g.
            ``meta_graph_def.signature_def``.
        field: Which side of each signature to read: ``"inputs"`` or
            ``"outputs"``.

    Returns:
        ``{signature_name: {key: {"name": tensor_name, "shape": shape}}}``.
        Results are nested per signature because the same key may appear in
        several signatures; a flat ``{key: name}`` dict would silently keep
        only the last one. ``shape`` is a list of ints (``-1`` marks a
        dimension of unknown size) or ``None`` when the rank is unknown.
    """
    result = {}
    for sig_name, sig_def in signatures.items():
        tensors = {}
        for key, tensor_info in getattr(sig_def, field).items():
            shape_proto = tensor_info.tensor_shape
            if shape_proto.unknown_rank:
                # No dim entries exist in this case; an empty list would
                # wrongly read as a scalar, so report the shape as unknown.
                shape = None
            else:
                shape = [int(dim.size) for dim in shape_proto.dim]
            tensors[key] = {"name": tensor_info.name, "shape": shape}
        result[sig_name] = tensors
    return result


if __name__ == "__main__":
    # Imported here so collect_signature_tensors can be reused and tested
    # without TensorFlow installed.
    import tensorflow as tf
    # NOTE: saved_model_utils lives under tensorflow.python, which is internal
    # (not part of the public tf.* API) and may change between releases.
    from tensorflow.python.tools import saved_model_utils

    model_dir = 'model_dir'  # SavedModel directory to inspect; edit as needed
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        model_dir, tf.saved_model.SERVING
    )
    signatures = meta_graph_def.signature_def
    output_tensor_info = collect_signature_tensors(signatures, "outputs")
    print(output_tensor_info)