TensorFlow模型的持久化
保存模型
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#声明两个变量并计算它们的和。
v1 = tf.Variable(tf.constant(1.0 , shape=[1]) , name="v1")
v2 = tf.Variable(tf.constant(2.0 , shape=[1]) , name="v2")
result = v1 + v2
init_op= tf.global_variables_initializer()
#声明 tf.train.Saver 类用于保存模型。
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#将模型保存到/model/model.ckpt 文件。
saver.save(sess,"model/model.ckpt")
- 运行之后model文件夹会生成四个文件

1. model.ckpt.meta 保存了tensorflow计算图的结构
2.model.ckpt保存了每一个变量的取值
3. checkpoint文件报错了一个目录下所有的模型文件列表
加载模型
- 和保存模型的代码几乎是一样的。唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,是直接将已经保存过的模型加载进来。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#使用和保存模型代码中一样的方式来卢明变量。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
#加载己经保存的模型,并通过已经保存的模型中变量的值来计算加法。
saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"
print(sess.run(result))
# 输出[ 3 .]

直接加载
- 直接加载模型中的全部变量。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#直接加载持久化的图。
saver = tf.train.import_meta_graph ("model/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess , "model/model.ckpt")
#通过张量的名称来获取张量。
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
#输出[ 3 .]

加载部分变量
- 使用tf.train.Saver([v1])只加载变量v1
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver([v1])
with tf.Session() as sess:
saver.restore(sess, "./Model/model.ckpt")
print(sess.run(result))
- 报错提示变量v2未初始化。

加载时重命名变量
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
#重命名变量名称
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="rename-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="rename-v2")
result = v1 + v2
saver = tf.train.Saver({"v1":v1,"v2":v2})
with tf.Session() as sess:
#加载己经保存的模型,并通过已经保存的模型中变量的值来计算加法。
saver.restore(sess, "model/model.ckpt") # 注意此处路径前添加"./"
print(sess.run(result))
# 输出[ 3 .]


浙公网安备 33010602011771号