MindSpore v1.0使用体验之离线模型导出
前言
Mindspore离线模型导出功能主要依靠Export接口实现,主要用于非训练场景下的模型推理。可将训练好的checkpiont文件转换为AIR格式或者ONNX格式的离线模型文件,直接用于其他平台或框架。
Export功能需要基于Acsend910平台,所以用到的mindspore版本为MindSpore-Ascend R1.0。
硬件信息
Intel(R) Core(TM) i7-7700
Linux eulerosv2r7.x86_64
Ascend910 (C75B220SPC1009)
安装Mindspore
为了更好的适配RUN包版本,采用源码编译的方式安装MindSpore。首先配置git相关内容,从gitee仓库克隆代码,并切换到r1.0分支。
git clone https://gitee.com/mindspore/mindspore
git checkout r1.0
进入Mindspore目录下,根据run包安装地址配置环境变量,然后进行编译。值得注意的是编译环境有一定依赖(例如x86上的conda环境,确保Python版本为3.7.5,gcc版本为7.3.0)
export ASCEND_CUSTOM_PATH=/usr/local/Ascend/
sh build.sh -ed -z
完成后会在MindSpore/output目录下生成mindspore的whl包,使用pip安装即可。安装完成后需要对te、topi、hccl库进行刷新,以防出现环境问题。
以run包安装目录/usr/local/Ascend为例:
pip install mindspore_ascend-1.0.0-cp37-cp37m-linux_x86_64.whl
pip uninstall -y te
pip uninstall -y topi
pip uninstall -y hccl
pip install /usr/local/Ascend/fwkacllib/lib64/topi-0.4.0-py3-none-any.whl
pip install /usr/local/Ascend/fwkacllib/lib64/te-0.4.0-py3-none-any.whl
pip install /usr/local/Ascend/fwkacllib/lib64/hccl-0.1.0-py3-none-any.whl
安装完成后进入相应conda环境,尝试导入MindSpore模块,如果出现以下相似内容,且无报错信息,则表示安装成功。
python -c "import mindspore"
模型导出
接下来就是使用export功能进行离线模型的导出,首先来看一下官网对于export这个api的描述:
从图中可以看出,接口位于serialization这个目录下,调用前需要引入模块
from mindspore.train.serialization import export
同时,可以看出调用接口需要定义网络结构(net)、网络输入内容(input)、文件名称(file_name)和模型格式(format)。
网络结构定义根据不同场景,一般从已经写好的网络结构脚本引入。值得说明的是export可以导出随机权重的离线模型也可以导出带有checkpoint参数的离线模型,如果需要进行参数加载,则需要额外导入以下模块:
from mindspore.train.serialization import load_checkpoint, load_param_into_net
模型格式的选择参考官网的说明,可以是AIR、ONNX和Mindir,通常air格式的离线模型用于Ascend310平台,Mindir格式将作为后续的MindSpore的通用格式。
做完以上选择后,给你的文件想个好名字就可以愉快的进行离线模型导出了~他的建议是用网络名称+batchsize信息作为命名格式,还可以适当的加上cpkt的精度信息。
以lenet网络为例,导出代码为:
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
path where the trained ckpt file')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE)
network = LeNet5(cfg.num_classes)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
input_data = Tensor(np.random.uniform(0.0, 1.0, size=[1, 1, 32, 32]), ms.float32)
export(network, input_data, file_name='./LeNet5_bs_1.air', file_format="AIR")
运行完成后,可以在当前目录下看到一个以.air后缀结尾的文件,至此大功告成。