1. 确定权重名称:

tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

for tmp in tvars1:

  print('all-->',tmp.name)

2. 根据网络结构从1中找到想要打印的权重名称 weight_name,通过下面的方式进行打印

fc_logits=tf.get_default_graph().get_tensor_by_name(weight_name)

with tf.Session() as sess:
  init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
  sess.run(init_op)
  fc_logits_ = sess.run(fc_logits,feed_dict={input_placeholder: gen_input,gt: label})
  print('fc_logits_:',fc_logits_)

 

  

posted on 2023-05-09 19:05  一点飞鸿  阅读(80)  评论(0编辑  收藏  举报