tensorflow 如何获取graph中的所有tensor name

import tensorflow as tf

saved_model_dir = "./saved_model"

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["serve"], saved_model_dir)
    graph = tf.get_default_graph()
    [print(n.name) for n in tf.get_default_graph().as_graph_def().node]

# 得到name之后,就可以获取相应的tensor了,例如:
# input_tensor = sess.graph.get_tensor_by_name('input:0')
# output_tensor = sess.graph.get_tensor_by_name('output:0')
posted @ 2020-04-02 15:51  ZH奶酪  阅读(2996)  评论(0编辑  收藏  举报