torch:
1. 通过model.state_dict()输出模型结构,结构中key是权重名称,value是权重的值
2. 根据权重名称获取权重:
fc_weight= model.state_dict()['fc_cls.weight']# 权重名称为:fc_cls.weight
tensorflow
1. 首先要知道获取哪个tensor的权重:
打印tensor的名称
tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for tmp in tvars:
print(tmp.name)
2. 根据1中的名称获取权重:
weight = tf.get_default_graph().get_tensor_by_name('variable name')
weight需要session.run()一下,返回结果
过去已逝,未来太远,只争今朝
浙公网安备 33010602011771号