Tensorflow 在一个新图中导入已训练好的旧图的某个变量

需求如题,不同于复用整个网络,只需要复用某个变量然后重新构图。

预训练文件简化为:

import tensorflow as tf 
import numpy as np 

w = tf.Variable(10,name='weightaaa')
b = tf.Variable(2,name='biasaaa')

x = tf.placeholder(dtype=tf.int32,shape=())
y = tf.add(tf.multiply(w,x),b)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _y = sess.run(y,feed_dict={x:1})
    print("otuput:%s"%(_y))
    saver.save(sess,'./testsavercheckpoint_dir/testsavermodel1')

懵逼着尝试了:

import tensorflow as tf 

w = tf.Variable(5,name='weight')
b = tf.Variable(2,name='bias')

x = tf.placeholder(dtype=tf.int32,shape=())
y = tf.add(tf.multiply(w,x),b)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./testsavercheckpoint_dir/testsavermodel1.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./testsavercheckpoint_dir'))
    w = tf.get_default_graph().get_tensor_by_name('weight:0')
    sess.run(tf.global_variables_initializer())
    _y = sess.run(y,feed_dict={x:1})
    print("otuput:%s"%(_y))
    

结果并不是12,而是7.想想也是,静态图已经构建好了,sess直接用静态途中的‘weight’了,还有我调用的旧的啥事儿。

于是

import tensorflow as tf 

sess = tf.Session()
saver = tf.train.import_meta_graph('./testsavercheckpoint_dir/testsavermodel1.meta')
saver.restore(sess,tf.train.latest_checkpoint('./testsavercheckpoint_dir'))

graph = tf.get_default_graph()
w_temp = tf.Variable(5,name='weight')
w = graph.get_tensor_by_name('weightaaa:0')
b = tf.Variable(2,name='bias')

x = tf.placeholder(dtype=tf.int32,shape=())
y = tf.add(tf.multiply(w,x),b)

sess.run(tf.global_variables_initializer())
_y = sess.run(y,feed_dict={x:1})
print("otuput:%s"%(_y))
for i in tf.global_variables():
    print(i)

即可。

注意tf里变量名会是给的name[_number]:0。如果重名,用number区分。

比如第一个weight就是weight:0,第二个就是weight_1:0。

 

参考资料:

https://blog.csdn.net/huachao1001/article/details/78501928

posted @ 2018-11-29 17:25  大胖子球花  阅读(449)  评论(0)    收藏  举报