Tensorflow系列——Saver的用法

摘抄自:https://blog.csdn.net/u011500062/article/details/51728830/

1、实例

 1 import tensorflow as tf
 2 import numpy as np
 3 
 4 x = tf.placeholder(tf.float32, shape=[None, 1])
 5 y = 4 * x + 4
 6 
 7 w = tf.Variable(tf.random_normal([1], -1, 1))
 8 b = tf.Variable(tf.zeros([1]))
 9 y_predict = w * x + b
10 
11 loss = tf.reduce_mean(tf.square(y - y_predict))
12 optimizer = tf.train.GradientDescentOptimizer(0.5)
13 train = optimizer.minimize(loss)
14 
15 isTrain = False
16 train_steps = 100
17 checkpoint_steps = 50
18 checkpoint_dir = './checkpoint_dir/'
19 
20 saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
21 x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
22 
23 with tf.Session() as sess:
24     sess.run(tf.initialize_all_variables())
25     if isTrain:
26         for i in range(train_steps):
27             sess.run(train, feed_dict={x: x_data})
28             if (i + 1) % checkpoint_steps == 0:
29                 saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
30     else:
31         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
33         if ckpt and ckpt.model_checkpoint_path:
34             saver.restore(sess, ckpt.model_checkpoint_path)
35             print("Restore Sucessfully")
36         else:
37             pass
38         print(sess.run(w))
39         print(sess.run(b))

2、运行结果

 

 

3、解释

训练阶段,每经过checkpoint_steps 步保存一次变量,保存的文件夹为checkpoint_dir

测试阶段,ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么,然后载入变量

 

posted @ 2019-02-15 10:03  笑着刻印在那一张泛黄  阅读(289)  评论(0编辑  收藏  举报