TensorFlow pb 模型的保存与加载

 

import tensorflow as tf  
from tensorflow.python.framework import graph_util  
  
# Build a tiny graph: two variables and their element-wise sum.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
# Name the output op explicitly instead of relying on the auto-generated name
# of `v1 + v2`; the loader fetches it as "add:0".
result = tf.add(v1, v2, name="add")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Export the GraphDef of the current graph, i.e. the computation from the
    # input layer to the output layer.
    graph_def = tf.get_default_graph().as_graph_def()
    # Freeze: replace the variables feeding the output op with constants that
    # hold their current values, so the .pb file is self-contained.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, [result.op.name])

    # GFile does not create missing parent directories, so create them first.
    tf.gfile.MakeDirs("D:/model/testAdd")
    with tf.gfile.GFile("D:/model/testAdd/combined_model.pb", 'wb') as f:
        f.write(output_graph_def.SerializeToString())

# Optional sanity check: evaluate the graph directly (prints [3.]).
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     print("---------")
#     print(sess.run(result))

运行后会生成模型文件 D:\model\testAdd\combined_model.pb

加载模型并计算

import tensorflow as tf  
from tensorflow.python.platform import gfile  
  
with tf.Session() as sess:
    model_filename = "D:/model/testAdd/combined_model.pb"
    # FastGFile is deprecated; GFile is the supported replacement.
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the frozen graph into the default graph (the session's graph) and
    # return the tensor produced by the node named 'add' when it was saved.
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))  # [array([ 3.], dtype=float32)]

 

带输入占位符（placeholder）的例子

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Input placeholder; the loader remaps the tensor 'matrix_:0' by this name.
    matrix = tf.placeholder(tf.int32, [1, 2], name='matrix_')
    op = tf.add(matrix, matrix, name='output')
    # The following line can be omitted; it only checks that the graph runs.
    print(sess.run(op, feed_dict={matrix: [[2, 3]]}))  # [[4 6]]

    # Save the model. The graph has no variables, so there is nothing to
    # freeze (the original initializer call was a no-op). The call still
    # produces a GraphDef for the 'output' node.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['output'])
    # GFile does not create missing parent directories, so create them first.
    tf.gfile.MakeDirs('D:/model')
    with tf.gfile.GFile('D:/model/cxq.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

加载模型（通过 input_map 将新的占位符接到模型的输入上）

import tensorflow as tf
from tensorflow.python.platform import gfile  
  
# New placeholder that replaces the saved graph's input 'matrix_:0'.
matrix1 = tf.placeholder(tf.int32, [1, 2])

with tf.Session() as sess:
    model_filename = "D:/model/cxq.pb"
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # The file has been fully read, so import and run outside the `with` block
    # to release the handle early. input_map rewires the saved input tensor
    # 'matrix_:0' to matrix1; return_elements fetches the saved 'output:0'.
    output = tf.import_graph_def(graph_def, input_map={'matrix_:0': matrix1},
                                 return_elements=['output:0'])
    print(output)
    text_list = sess.run(output, feed_dict={matrix1: [[1, 2]]})
    print(text_list)  # [array([[2, 4]], dtype=int32)]

 

posted @ 2019-04-11 17:15  牧 天  阅读(652)  评论(0)    收藏  举报