MXNet Symbol：打印模型所有中间层的输出

代码:

import mxnet as mx


def get_output_symbol(symbol):
    """Return the names of the graph nodes whose outputs should be dumped.

    Every operator node of ``symbol`` is included. Variable (``"null"``)
    nodes -- the input ``data``, weights, biases, ... -- are skipped unless
    the variable is itself one of the symbol's outputs.

    Parameters
    ----------
    symbol : mxnet.symbol.Symbol
        Symbol whose computation graph is inspected.

    Returns
    -------
    list of str
        Node names, in the order the nodes appear in ``symbol.tojson()``.

    Raises
    ------
    TypeError
        If ``symbol`` is not an MXNet ``Symbol``.
    """
    import json
    from mxnet.symbol.symbol import Symbol
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")

    conf = json.loads(symbol.tojson())
    # Each head is a [node_id, output_index, version] triple; only the first
    # element identifies a node, and every head must be considered.
    head_ids = {head[0] for head in conf["heads"]}
    return [
        node["name"]
        for node_id, node in enumerate(conf["nodes"])
        if node["op"] != "null" or node_id in head_ids
    ]

def debug_model(model, data_shape=(1, 3, 112, 112), ctx=None, output_dir='output'):
    """Run one forward pass and dump every intermediate output to text files.

    An all-zero input is fed through ``model``. The output of every node
    returned by ``get_output_symbol`` (except ``data``) is captured and
    written to ``<output_dir>/<node name>.txt``.

    Parameters
    ----------
    model : mxnet.mod.Module
        Bound module with initialized parameters (e.g. from ``get_model``).
    data_shape : tuple of int, optional
        Shape of the zero input bound as ``data``; defaults to (1, 3, 112, 112).
    ctx : mxnet.Context, optional
        Device to run on; defaults to ``mx.cpu()``.
    output_dir : str, optional
        Directory receiving one text file per node; created if missing.
    """
    import os

    if ctx is None:
        ctx = mx.cpu()

    # Prepare an all-zero input batch.
    input_blob = mx.nd.zeros(shape=data_shape, ctx=ctx)
    db = mx.io.DataBatch(data=(input_blob,))

    # Nodes to probe; 'data' is the input itself, not an intermediate result.
    symbols = [x for x in get_output_symbol(model.symbol) if x != 'data']
    arg_params, aux_params = model.get_params()
    internals = model.symbol.get_internals()
    # NOTE(review): assumes single-output operators, whose internal output is
    # named '<node>_output'; multi-output operators name theirs
    # '<node>_output0', ... and would fail this lookup.
    symbols_output = [internals[x + '_output'] for x in symbols]

    # Rebuild a module whose outputs are all the probed intermediate results.
    group = mx.symbol.Group(symbols_output)
    mod = mx.mod.Module(symbol=group, context=ctx)
    # Inference only: no gradient buffers are needed.
    mod.bind(data_shapes=[('data', data_shape)], for_training=False)
    mod.set_params(arg_params, aux_params)
    mod.forward(db, is_train=False)
    # Group outputs come back in the same order as ``symbols``.
    output_dict = {k: v.asnumpy() for k, v in zip(symbols, mod.get_outputs())}

    # Save the results, one file per node.
    os.makedirs(output_dir, exist_ok=True)
    for name, array in output_dict.items():
        with open(os.path.join(output_dir, '{}.txt'.format(name)), 'w') as f:
            _write_array(array, f)


def _write_array(array, f):
    """Write a numpy ``array`` to the open text file ``f``, batch by batch, channel by channel."""
    print('Shape is {}, data type is {}'.format(array.shape, array.dtype), file=f)
    for i, batch in enumerate(array):
        print('Batch {}:'.format(i), file=f)
        if batch.ndim == 0:
            # 1-D output: each batch entry is already a scalar.
            print(' ' * 4 + str(batch), file=f)
            continue
        for j, channel in enumerate(batch):
            print('{}Channel {}:'.format(' ' * 4, j), file=f)
            if channel.ndim == 0:
                # 2-D output (e.g. a fully connected layer): one value per channel.
                print(' ' * 8 + str(channel), file=f)
                continue
            if channel.ndim == 1:
                rows = [channel]
            elif channel.ndim == 2:
                rows = channel
            else:
                # 5-D+ outputs: flatten leading dims so rows still print one per line.
                rows = channel.reshape(-1, channel.shape[-1])
            for row in rows:
                print(' ' * 8 + ''.join(str(value) + '  ' for value in row), file=f)

def get_model(ctx, image_size, model_str, layer):
    """Load a pretrained checkpoint, truncate it at ``layer`` and bind it for inference.

    Parameters
    ----------
    ctx : mxnet.Context
        Device the module is bound to.
    image_size : tuple of int
        Spatial input size; the module is bound for a single 3-channel input
        of shape (1, 3, image_size[0], image_size[1]).
    model_str : str
        ``"<prefix>,<epoch>"`` as passed to ``mx.model.load_checkpoint``.
    layer : str
        Node whose internal output ``<layer>_output`` becomes the module output.

    Returns
    -------
    mxnet.mod.Module
        Bound module with the checkpoint parameters loaded.

    Raises
    ------
    ValueError
        If ``model_str`` has no comma or its epoch part is not an integer.
    """
    prefix, sep, epoch_str = model_str.rpartition(',')
    if not sep:
        raise ValueError("model_str must look like '<prefix>,<epoch>', got {!r}".format(model_str))
    epoch = int(epoch_str)
    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Cut the graph at the requested layer.
    sym = sym.get_internals()[layer + '_output']

    data_shape = (1, 3, image_size[0], image_size[1])
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Inference only: skip allocating gradient buffers.
    model.bind(data_shapes=[('data', data_shape)], for_training=False)
    model.set_params(arg_params, aux_params)
    # Print a layer-by-layer summary (including output shapes) of the truncated network.
    mx.viz.print_summary(sym, {'data': data_shape})
    return model
if __name__ == '__main__':
    # Truncate the checkpoint at 'fc1' on CPU, then dump every intermediate output.
    model = get_model(
        mx.cpu(),
        (112, 112),
        'model-y1-test2/model,0',
        'fc1',
    )
    debug_model(model)

结果（每个中间层的输出会写入 `output/` 目录下以层名命名的 `.txt` 文件，文件首行为该层输出的 shape 和数据类型）

posted @ 2020-08-25 20:01  mengfu188  阅读(141)  评论(0)    收藏  举报