TensorFlow PB 模型的保存与加载
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Name of the op whose output is exported; the loader fetches "add:0".
OUTPUT_NODE = "add"
MODEL_PATH = "D:/model/testAdd/combined_model.pb"

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
# Name the sum explicitly instead of relying on TensorFlow's auto-generated
# default op name, so the node handed to convert_variables_to_constants
# below is guaranteed to be the one we mean.
result = tf.add(v1, v2, name=OUTPUT_NODE)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Export the GraphDef of the current graph, i.e. the computation from the
    # input layer to the output layer.
    graph_def = tf.get_default_graph().as_graph_def()
    # Replace the variables with constants holding their current values and
    # keep only the subgraph needed to compute the output node.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, [result.op.name])
    # Opening a file for writing does not create missing parent directories,
    # so create the target directory first.
    tf.gfile.MakeDirs(os.path.dirname(MODEL_PATH))
    with tf.gfile.GFile(MODEL_PATH, "wb") as f:
        f.write(output_graph_def.SerializeToString())
运行后生成模型文件 D:\model\testAdd\combined_model.pb。
加载模型并计算
import tensorflow as tf
from tensorflow.python.platform import gfile

model_filename = "D:/model/testAdd/combined_model.pb"

# Deserialize the frozen GraphDef written by the export script.
# gfile.GFile is used because gfile.FastGFile is deprecated.
with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Import the frozen graph into the current graph and get back the output
    # tensor of its "add" node (return_elements uses names from graph_def).
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))  # [array([3.], dtype=float32)]
带输入参数（placeholder）的例子
# -*- coding:utf-8 -*-
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

MODEL_PATH = 'D:/model/cxq.pb'

with tf.Session() as sess:
    # Input placeholder. The loader remaps it with input_map={'matrix_:0': ...},
    # so its name must stay 'matrix_'.
    matrix = tf.placeholder(tf.int32, [1, 2], name='matrix_')
    op = tf.add(matrix, matrix, name='output')

    # Sanity check only; this line can be removed. Prints [[4 6]].
    print(sess.run(op, feed_dict={matrix: [[2, 3]]}))

    # Save the model. The graph defines no variables, so this call only
    # extracts the subgraph needed to compute 'output'; no
    # global_variables_initializer run is required.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['output'])
    # Make sure the target directory exists before writing.
    tf.gfile.MakeDirs(os.path.dirname(MODEL_PATH))
    with tf.gfile.GFile(MODEL_PATH, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
加载模型，并用 input_map 将模型输入替换为新的 placeholder
import tensorflow as tf
from tensorflow.python.platform import gfile


def _load_graph_def(path):
    """Read a binary-serialized GraphDef from *path* and return it."""
    parsed = tf.GraphDef()
    with gfile.GFile(path, 'rb') as model_file:
        parsed.ParseFromString(model_file.read())
    return parsed


# New placeholder that stands in for the saved graph's 'matrix_' input.
new_input = tf.placeholder(tf.int32, [1, 2])

with tf.Session() as sess:
    frozen = _load_graph_def("D:/model/cxq.pb")
    # Splice the frozen graph into the current graph, wiring its 'matrix_'
    # input to new_input, and get back its 'output' tensor (as a list).
    fetches = tf.import_graph_def(
        frozen,
        input_map={'matrix_:0': new_input},
        return_elements=['output:0'],
    )
    print(fetches)
    values = sess.run(fetches, feed_dict={new_input: [[1, 2]]})
    print(values)  # [array([[2, 4]], dtype=int32)]

浙公网安备 33010602011771号