Tensorflow 在一个新图中导入已训练好的旧图的某个变量
需求如题,不同于复用整个网络,只需要复用某个变量然后重新构图。
预训练文件简化为:
import tensorflow as tf import numpy as np w = tf.Variable(10,name='weightaaa') b = tf.Variable(2,name='biasaaa') x = tf.placeholder(dtype=tf.int32,shape=()) y = tf.add(tf.multiply(w,x),b) saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) _y = sess.run(y,feed_dict={x:1}) print("otuput:%s"%(_y)) saver.save(sess,'./testsavercheckpoint_dir/testsavermodel1')
懵逼着尝试了:
import tensorflow as tf w = tf.Variable(5,name='weight') b = tf.Variable(2,name='bias') x = tf.placeholder(dtype=tf.int32,shape=()) y = tf.add(tf.multiply(w,x),b) with tf.Session() as sess: saver = tf.train.import_meta_graph('./testsavercheckpoint_dir/testsavermodel1.meta') saver.restore(sess,tf.train.latest_checkpoint('./testsavercheckpoint_dir')) w = tf.get_default_graph().get_tensor_by_name('weight:0') sess.run(tf.global_variables_initializer()) _y = sess.run(y,feed_dict={x:1}) print("otuput:%s"%(_y))
结果并不是12,而是7.想想也是,静态图已经构建好了,sess直接用静态途中的‘weight’了,还有我调用的旧的啥事儿。
于是
import tensorflow as tf sess = tf.Session() saver = tf.train.import_meta_graph('./testsavercheckpoint_dir/testsavermodel1.meta') saver.restore(sess,tf.train.latest_checkpoint('./testsavercheckpoint_dir')) graph = tf.get_default_graph() w_temp = tf.Variable(5,name='weight') w = graph.get_tensor_by_name('weightaaa:0') b = tf.Variable(2,name='bias') x = tf.placeholder(dtype=tf.int32,shape=()) y = tf.add(tf.multiply(w,x),b) sess.run(tf.global_variables_initializer()) _y = sess.run(y,feed_dict={x:1}) print("otuput:%s"%(_y)) for i in tf.global_variables(): print(i)
即可。
注意tf里变量名会是给的name[_number]:0。如果重名,用number区分。
比如第一个weight就是weight:0,第二个就是weight_1:0。
参考资料:
https://blog.csdn.net/huachao1001/article/details/78501928

浙公网安备 33010602011771号