MXNet Symbol：打印模型所有中间层的输出
代码：
import mxnet as mx
def get_output_symbol(symbol):
    """Return the names of the nodes of `symbol` that compute an output.

    Every operator node (op != "null") is included.  A variable node
    (op == "null", e.g. the 'data' input or a weight) is included only
    if it is itself one of the symbol's heads.

    Parameters
    ----------
    symbol : Symbol
        The network whose graph is inspected.

    Returns
    -------
    list of str
        Node names, in the order the nodes appear in the graph JSON.

    Raises
    ------
    TypeError
        If `symbol` is not an mxnet Symbol.
    """
    import json
    from mxnet.symbol.symbol import Symbol
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")
    conf = json.loads(symbol.tojson())
    # Each head entry starts with the node id, followed by the output index
    # (and a version number).  Only the node id identifies a node; putting
    # the whole first entry into a set would mix index/version numbers in
    # with node ids and make node 0 always look like a head.
    head_ids = {head[0] for head in conf["heads"]}
    return [node["name"]
            for i, node in enumerate(conf["nodes"])
            if node["op"] != "null" or i in head_ids]
def _output_key(output_name, wanted):
    """Map an internal output name to its dump key, or None if not wanted.

    'conv0_output' -> 'conv0'.  Outputs of multi-output ops
    ('slice0_output1') keep their full name so each gets its own file.
    Variables ('data', 'conv0_weight', ...) have no '_output' suffix and
    are skipped.
    """
    node_name, sep, index = output_name.rpartition('_output')
    if not sep or node_name not in wanted:
        return None
    if index == '':
        return node_name
    if index.isdigit():
        return output_name
    return None


def _iter_rows(array):
    """Yield every 1-D row of `array` (ndim >= 1), last axis varying fastest."""
    if array.ndim == 1:
        yield array
    else:
        for sub in array:
            yield from _iter_rows(sub)


def _write_output_file(path, value):
    """Write one output array as text: per batch item, per channel, one line per row."""
    with open(path, 'w') as f:
        print('Shape is {}, data type is {}'.format(value.shape, value.dtype), file=f)
        if value.ndim < 2:
            # Scalar / (N,) outputs: treat each sample as one single-value channel.
            value = value.reshape(-1, 1)
        for b, sample in enumerate(value):
            print('Batch {}:'.format(b), file=f)
            for c, channel in enumerate(sample):
                print('{}Channel {}:'.format(' ' * 4, c), file=f)
                if channel.ndim == 0:
                    # e.g. (N, D) outputs: each "channel" is a single number.
                    print(' ' * 8 + str(channel), file=f)
                    continue
                # 1-D, 2-D (H, W) or higher-rank channels: one text line per row.
                for row in _iter_rows(channel):
                    f.write(' ' * 8 + ''.join(str(x) + ' ' for x in row) + '\n')


def debug_model(model, data_shape=(1, 3, 112, 112), out_dir='output'):
    """Run one all-zero forward pass and dump every intermediate output to text.

    Parameters
    ----------
    model : mx.mod.Module
        Module with parameters set (e.g. the one returned by get_model);
        its `symbol` and `get_params()` are used to rebuild a module that
        exposes all internal outputs.
    data_shape : tuple of int
        Shape of the all-zero 'data' input (the default is the shape this
        script always used).
    out_dir : str
        Directory receiving one '<name>.txt' file per output; created if
        missing.

    Returns
    -------
    dict
        Maps each dumped output name to its numpy array.

    Raises
    ------
    ValueError
        If the symbol has no intermediate operator outputs.
    """
    import os

    # Dummy all-zero input batch.
    input_blob = mx.nd.zeros(shape=data_shape, ctx=mx.cpu())
    db = mx.io.DataBatch(data=(input_blob,))

    # Operator nodes whose outputs we want; the input variable is not one.
    wanted = set(x for x in get_output_symbol(model.symbol) if x != 'data')
    arg_params, aux_params = model.get_params()

    # Select outputs from list_outputs() instead of indexing
    # internals[name + '_output'], which fails for multi-output ops whose
    # outputs are named '<name>_output0', '<name>_output1', ...
    internals = model.symbol.get_internals()
    keys, symbols_output = [], []
    for output_name in internals.list_outputs():
        key = _output_key(output_name, wanted)
        if key is not None:
            keys.append(key)
            symbols_output.append(internals[output_name])
    if not symbols_output:
        raise ValueError('symbol has no intermediate operator outputs to dump')

    # Rebuild a module whose outputs are all of the selected internal outputs.
    group = mx.symbol.Group(symbols_output)
    mod = mx.mod.Module(symbol=group, context=mx.cpu())
    # Inference only: for_training=False skips allocating gradient buffers.
    mod.bind(data_shapes=[('data', data_shape)], for_training=False)
    mod.set_params(arg_params, aux_params)
    mod.forward(db, is_train=False)
    output_dict = {k: v.asnumpy() for k, v in zip(keys, mod.get_outputs())}

    # Save one text file per output.
    os.makedirs(out_dir, exist_ok=True)
    for name, value in output_dict.items():
        # Node names may contain '/', which would point into a missing subdirectory.
        file_name = '{}.txt'.format(name.replace('/', '_'))
        _write_output_file(os.path.join(out_dir, file_name), value)
    return output_dict
# Load a pre-trained checkpoint and truncate it at a given layer.
def get_model(ctx, image_size, model_str, layer):
    """Load a checkpoint, cut the graph at `layer`, and bind it for one image.

    Parameters
    ----------
    ctx : mx.Context
        Context the module is created on.
    image_size : (int, int)
        Input height and width; the module is bound to
        (1, 3, image_size[0], image_size[1]).
    model_str : str
        "<prefix>,<epoch>" as passed to mx.model.load_checkpoint.  Only the
        last comma separates the epoch, so the prefix may contain commas.
    layer : str
        Name of the node whose output becomes the module's output.

    Returns
    -------
    mx.mod.Module
        Bound module with the checkpoint's parameters set.

    Raises
    ------
    ValueError
        If `model_str` has no "<prefix>,<epoch>" form or the epoch is not
        an integer.
    """
    prefix, sep, epoch_str = model_str.rpartition(',')
    if not sep or not prefix:
        raise ValueError("model_str must look like '<prefix>,<epoch>', got %r" % (model_str,))
    epoch = int(epoch_str)
    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Keep only the sub-graph that ends at `layer`.
    sym = sym.get_internals()[layer + '_output']
    data_shape = (1, 3, image_size[0], image_size[1])
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    model.bind(data_shapes=[('data', data_shape)])
    model.set_params(arg_params, aux_params)
    # Print the per-layer output shapes for this input size.
    mx.viz.print_summary(sym, {'data': data_shape})
    return model
if __name__ == '__main__':
    # Load the checkpoint truncated at 'fc1' and dump all of its intermediate outputs.
    checkpoint = 'model-y1-test2/model,0'
    net = get_model(mx.cpu(), (112, 112), checkpoint, 'fc1')
    debug_model(net)
结果：每个中间层的输出保存在 `output/` 目录下，每层一个 `.txt` 文件（首行为形状和数据类型）。

浙公网安备 33010602011771号