Tensorflow savedmodel to graph def
1.使用tf2onnx工具,把saved model转换为tf的graph def(不带function,也就是tf1的计算图)
https://github.com/onnx/tensorflow-onnx/blob/v1.9.3/tf2onnx/tf_loader.py
# -*- coding: utf-8 -*-
import os
import multiprocessing
from typing import List, Dict
try:
from tf2onnx import tf_loader
except ImportError:
# install tf2onnx
import subprocess
subprocess.call(["sudo", "/usr/bin/python3", "-m", "pip", "install", "tf2onnx==1.9.3"])
from tf2onnx import tf_loader
from tensorflow.core.protobuf import meta_graph_pb2, config_pb2
from tensorflow.python.grappler import tf_optimizer
from google.protobuf import text_format
from tensorflow.core.protobuf import rewriter_config_pb2
import tensorflow as tf
# Grappler optimizer names that run_graph_grappler enables when the caller
# does not pass its own `optimizers`; only the 'dependency' pass is listed.
DEFAULT_OPTIMIZERS = ('dependency',)
def run_graph_grappler(graph, inputs, outputs, optimizers=DEFAULT_OPTIMIZERS):
    """Run the Grappler optimizer over ``graph`` and return the result.

    Args:
        graph: graph definition handed to ``export_meta_graph`` as ``graph_def``.
        inputs: input names, recorded in the meta graph's 'train_op' collection.
        outputs: output names, recorded in the same collection after ``inputs``.
        optimizers: Grappler optimizer names to enable.

    Returns:
        Whatever ``tf_optimizer.OptimizeGraph`` returns for the configured
        meta graph (the optimized graph definition).
    """
    # NOTE: this switches eager execution off for the whole process.
    tf.compat.v1.disable_eager_execution()

    session_config = config_pb2.ConfigProto()
    rewrite_options = session_config.graph_options.rewrite_options
    rewrite_options.optimizers.extend(optimizers)
    # Run the meta optimizer exactly once.
    rewrite_options.meta_optimizer_iterations = rewriter_config_pb2.RewriterConfig.ONE

    meta_graph_def = tf.compat.v1.train.export_meta_graph(graph_def=graph)

    # Inputs and outputs are registered under 'train_op'; presumably this is
    # how Grappler learns which nodes must survive optimization -- TODO confirm.
    keep_nodes = meta_graph_pb2.CollectionDef()
    keep_nodes.node_list.value.extend([*inputs, *outputs])
    meta_graph_def.collection_def['train_op'].CopyFrom(keep_nodes)

    return tf_optimizer.OptimizeGraph(session_config, meta_graph_def)
def is_control_dependency(node_name: str) -> bool:
    """Return True when ``node_name`` is a control input, i.e. begins with ``^``."""
    # Slicing (rather than indexing) keeps the empty string safe.
    return node_name[:1] == "^"
def is_saved_model_control_node(node: tf.compat.v1.NodeDef) -> bool:
    '''
    Decide whether ``node`` is one of the control NoOp nodes found in a
    converted saved model. Such a node looks like:
    node {
    name: "Func/StatefulPartitionedCall/input_control_node/_0"
    op: "NoOp"
    input: "^deep_fm4_1024"
    input: "^deep_fm4_1552"
    }
    These nodes should be removed if we need to run inference on the subgraph.

    Returns True only when the op is "NoOp", the name contains
    "input_control_node" or "output_control_node", and every input is a
    control dependency (a node with no inputs at all also qualifies).
    '''
    name_markers = ("input_control_node", "output_control_node")
    if node.op == "NoOp" and any(marker in node.name for marker in name_markers):
        return all(is_control_dependency(name) for name in node.input)
    return False
def fix_saved_model_control_dependency(graph_def: tf.compat.v1.GraphDef):
    """Drop control edges that point at saved-model control NoOp nodes.

    Every node for which ``is_saved_model_control_node`` is true is collected,
    and any ``^<name>`` input referring to one of them is removed from all
    nodes in the graph. The control NoOp nodes themselves are left in the
    graph; only the edges to them are removed. The relative order of the
    remaining inputs of each node is preserved.

    Args:
        graph_def: graph to clean up; it is modified in place.

    Returns:
        The same ``graph_def`` object, for call chaining.
    """
    # Control-input spellings ("^<name>") of every saved-model control node.
    control_refs = {
        "^" + node.name
        for node in graph_def.node
        if is_saved_model_control_node(node)
    }
    if not control_refs:
        return graph_def

    for node in graph_def.node:
        # Walk backwards so that deleting index i never shifts an index that
        # is still to be visited; no swapping is needed, which keeps the
        # surviving inputs in their original order.
        for i in reversed(range(len(node.input))):
            if node.input[i] in control_refs:
                del node.input[i]
    return graph_def
def fix_output_name(graph_def: "tf.compat.v1.GraphDef", outputs: List[str], alias_map: Dict[str, str]):
    '''
    Rename output nodes to their alias in place so that serving won't need
    an alias mapping.

    outputs looks like:
    ['Identity:0', 'Identity_1:0',
    'Identity_2:0', 'Identity_3:0',
    'Identity_4:0', 'Identity_5:0',
    'Identity_6:0', 'Identity_7:0']
    alias_map looks like:
    {'Identity:0': 'logit_dislike', 'Identity_1:0': 'logit_like',
    'Identity_2:0': 'logit_play', 'Identity_3:0': 'logit_staytime',
    'Identity_4:0': 'pred_dislike', 'Identity_5:0': 'pred_like',
    'Identity_6:0': 'pred_play', 'Identity_7:0': 'pred_staytime'}

    A node is renamed when "<node name>:0" appears in ``outputs``. Inputs of
    other nodes that refer to a renamed node ("name", "name:<port>" or
    "^name") are rewritten to the new name so the graph stays connected.

    Raises KeyError when an output tensor has no entry in ``alias_map``.
    Returns the same (mutated) ``graph_def``.
    '''
    output_tensors = set(outputs)

    # Pass 1: rename the output nodes and remember old name -> new name.
    renamed = {}
    for node in graph_def.node:
        tensor_name = node.name + ":0"
        if tensor_name in output_tensors:
            alias = alias_map[tensor_name]
            renamed[node.name] = alias
            node.name = alias
    if not renamed:
        return graph_def

    # Pass 2: repoint every reference to a renamed node. Inputs still carry
    # the old names here, so each reference is translated exactly once.
    for node in graph_def.node:
        # Iterate over a snapshot because elements are reassigned in the loop.
        for i, ref in enumerate(list(node.input)):
            prefix = "^" if ref.startswith("^") else ""
            base, sep, port = ref[len(prefix):].partition(":")
            if base in renamed:
                node.input[i] = prefix + renamed[base] + sep + port
    return graph_def
def convert_saved_model_to_graph_def(export_dir):
    """Convert the saved model in ``export_dir`` to a text GraphDef file.

    The saved model is loaded through ``tf_loader.from_saved_model``, passed
    through ``run_graph_grappler``, stripped of saved-model control
    dependencies, has its output nodes renamed to their aliases, and is then
    written as text to ``<export_dir>/graph.pbtxt``.

    Args:
        export_dir: directory that contains ``saved_model.pb``.

    Raises:
        FileNotFoundError: if ``<export_dir>/saved_model.pb`` does not exist.
    """
    print("Start to convert saved model to graph def pbtxt", flush=True)
    saved_model_file = "{}/saved_model.pb".format(export_dir)
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not os.path.exists(saved_model_file):
        raise FileNotFoundError(
            "saved model file not found: {}".format(saved_model_file))
    frozen_graph_def, inputs, outputs, alias_map = tf_loader.from_saved_model(
        export_dir, input_names=None, output_names=None,
        return_tensors_to_rename=True)
    # remove trivial Identity and control dependency for readability
    frozen_graph_def = run_graph_grappler(frozen_graph_def, inputs=inputs, outputs=outputs)
    frozen_graph_def = fix_saved_model_control_dependency(frozen_graph_def)
    frozen_graph_def = fix_output_name(frozen_graph_def, outputs, alias_map)
    graph_def_file = "{}/graph.pbtxt".format(export_dir)
    with open(graph_def_file, 'w') as f:
        f.write(text_format.MessageToString(frozen_graph_def))
    print("Convert saved model to graph def success", flush=True)
----2022.09.28补充--------------
通过阅读 tf_loader 的源码,发现在转换成 graph 的时候已经做了 grappler 优化,使用的优化器是 constfold 和 dependency。如果启用 constfold,中间节点会被常量折叠掉;不想被折叠的话,禁用 constfold 优化即可。但这需要修改 tf_loader.py 的源码(目前没找到只替换已 import 模块中某个函数的办法)。

浙公网安备 33010602011771号