torch:

1. 通过model.state_dict()输出模型结构,结构中key是权重名称,value是权重的值

2. 根据权重名称获取权重:

  fc_weight= model.state_dict()['fc_cls.weight']# 权重名称为:fc_cls.weight

tensorflow

1. 首先要知道获取哪个tensor的权重:

    打印tensor的名称 

  tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
  for tmp in tvars:
    print(tmp.name)

2. 根据1中的名称获取权重:

    weight =  tf.get_default_graph().get_tensor_by_name('variable name')

         weight需要session.run()一下,返回结果

posted on 2023-09-18 18:54  一点飞鸿  阅读(185)  评论(0)    收藏  举报