MindSpore Lite的ms模型在安卓手机端进行推理方法

注:本人利用imagenet网络训练ms模型,推到手机端进行图片推理。

加载模型:从文件系统中读取MindSpore Lite模型,并进行模型解析。

model = new Model();
if (!model.loadModel(context, "wangyuan.ms")) {
    Log.e("MS_LITE", "This demo loads model unsuccessfully");
    return false;
}

创建配置上下文:创建配置上下文MSConfig,保存会话所需的一些基本配置参数,用于指导图编译和图执行。主要包括deviceType:设备类型、threadNum:线程数、cpuBindMode:CPU绑定模式、enable_float16:是否优先使用float16算子。

MSConfig msConfig = new MSConfig();
if (!msConfig.init(DeviceType.DT_CPU, 2, CpuBindMode.MID_CPU, true)) {
    Log.e("MS_LITE", "This demo inits context unsuccessfully");
    return false;
}

创建会话:创建LiteSession,并调用init方法将上一步得到MSConfig配置到会话中。

session = new LiteSession();
if (!session.init(msConfig)) {
    Log.e("MS_LITE", "This demo creates session unsuccessfully");
    msConfig.free();
    return false;
}
msConfig.free();

图编译:在图执行前,需要调用LiteSession的compileGraph接口进行图编译,主要进行子图切分、算子选型调度。这部分会耗费较多时间,所以建议LiteSession创建一次,编译一次,多次执行。

if (!session.compileGraph(model)) {
    Log.e("MS_LITE", "This demo compiles graph unsuccessfully");
    model.freeBuffer();
    return false;
}
model.freeBuffer();

输入数据:图执行之前需要向输入Tensor中填充数据(注:本人输入用的是二进制的图片)

List<MSTensor> inputs = session.getInputs();
MSTensor inTensor = inputs.get(0);
byte[] inData = readFileFromAssets(context, "wangyuan_cat.bin");
inTensor.setData(inData);

图执行:使用LiteSession的runGraph进行模型推理。

if (!session.runGraph()) {
    Log.e("MS_LITE", "This demo runs graph unsuccessfully");
    return;
}

获得输出:图执行结束之后,可以通过输出Tensor得到推理结果。

List<String> tensorNames = session.getOutputTensorNames();
Map<String, MSTensor> outputs = session.getOutputMapByTensor();
Set<Map.Entry<String, MSTensor>> entries = outputs.entrySet();
for (String tensorName : tensorNames) {
    MSTensor output = outputs.get(tensorName);
    if (output == null) {
        Log.e("MS_LITE", "Can not find output " + tensorName);
        return;
    }
    float[] results = output.getFloatData();
}


释放内存:无需使用MindSpore Lite推理框架的时候,需要将创建的LiteSession和model进行释放。

private void free() {
    session.free();
    model.free();
}

 

本人选取的数据集是cat_dog数据集,输出结果在logcat里,利用results对应的最高值匹配label,输出相应的结果,结果如下:

WangYuan infer result is Cat !

posted @ 2021-12-30 19:34  MS小白  阅读(168)  评论(0)    收藏  举报