Android Development Notes [3]: Porting MNIST (TinyMaix)

Abstract

Port MNIST handwritten digit recognition, built on TinyMaix, to Android.

Platform

  • Android Studio: Electric Eel | 2022.1.1 Patch 2
  • Gradle:distributionUrl=https://services.gradle.org/distributions/gradle-7.5-bin.zip
  • jvmTarget = '1.8'
  • minSdk 21
  • targetSdk 33
  • compileSdk 33
  • Languages: Kotlin, C++, C
  • ndkVersion = '25.2.9519653'

Source code

[https://gitee.com/qsbye/AndTheStone/tree/mnist]

Introduction to MNIST handwritten digit recognition

[https://zhuanlan.zhihu.com/p/264960142]
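The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used to train image processing systems. It is also widely used for training and testing in machine learning. It was created by "re-mixing" samples from NIST's original datasets. The creators felt that because NIST's training set came from American Census Bureau employees while the test set came from American high school students, it was not well suited to machine learning experiments. In addition, the black-and-white NIST images were normalized to fit a 28x28 pixel bounding box and anti-aliased, which introduced grayscale levels.
The task of an MNIST handwritten digit recognition model is: take an image of a handwritten digit as input and identify which digit it shows.

The goal is clear and the task is simple; the dataset is clean and uniform, and it is small enough to train and run on an ordinary PC. It is the "Hello World!" of deep learning and a standard first model for anyone learning AI.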
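The last step of the task described above, turning ten per-class scores into one predicted digit, can be sketched as an argmax. This is only an illustration: the helper name predict_digit and the score values are made up for the example and are not taken from the project.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of TinyMaix or the project):
// return the index of the largest score; index i stands for digit i.
int predict_digit(const float* scores, std::size_t n) {
    assert(scores != nullptr && n > 0);
    std::size_t best = 0;
    for (std::size_t i = 1; i < n; ++i) {
        if (scores[i] > scores[best]) {
            best = i;  // keep the first index that holds the maximum
        }
    }
    return static_cast<int>(best);
}
```

Calling predict_digit on an array whose largest entry sits at index 2 returns 2, which is read as "the image shows the digit 2".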

Introduction to TinyMaix

[https://blog.csdn.net/xusiwei1236/article/details/131408560]
[https://github.com/sipeed/TinyMaix]
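TinyMaix is a lightweight AI inference framework developed by the Chinese team Sipeed. The official introduction reads:

TinyMaix is an ultra-lightweight neural network inference library for microcontrollers (a TinyML inference library) that lets you run lightweight deep learning models on any microcontroller.
According to the official introduction, TinyMaix can perform handwritten digit recognition even on an Arduino UNO (ATmega328, 32KB Flash, 2KB RAM). Yes, you read that right: a device with only 2KB of RAM and 32KB of Flash.
TinyMaix is designed for chips with little compute and little memory, and it provides support and optimizations for a range of microcontroller architectures, including RISC-V and ARM Cortex-M.
TinyMaix can be thought of simply as a matrix and vector computation library. It currently supports the following compute back ends: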

#define TM_ARCH_CPU         (0) //default, pure cpu compute
#define TM_ARCH_ARM_SIMD    (1) //ARM Cortex M4/M7, etc.
#define TM_ARCH_ARM_NEON    (2) //ARM Cortex A7, etc.
#define TM_ARCH_ARM_MVEI    (3) //ARMv8.1: M55, etc.
#define TM_ARCH_RV32P       (4) //T-head E907, etc.
#define TM_ARCH_RV64V       (5) //T-head C906,C910, etc.
#define TM_ARCH_CSKYV2      (6) //cskyv2 with dsp core
#define TM_ARCH_X86_SSE2    (7) //x86 sse2

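For ARM Cortex MCUs, both pure CPU computation and SIMD computation are supported. The CPU path has no special dependencies (the compute code is written entirely in standard C). In the SIMD path, part of the compute code uses inline assembly embedded in C, so the CPU must support the corresponding instructions for the code to compile and run.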
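To give a feel for what the pure-CPU path computes, here is a minimal sketch of an int8 dot product with a 32-bit accumulator, written in portable C++ with no intrinsics. It is not TinyMaix's actual kernel; the function name dot_i8 is hypothetical. A SIMD back end would compute the same sum several elements at a time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical illustration, not TinyMaix source code.
// Multiply two int8 vectors element by element and sum into int32,
// so that products (up to 128*128) cannot overflow the 8-bit inputs.
std::int32_t dot_i8(const std::int8_t* a, const std::int8_t* b, std::size_t n) {
    assert(a != nullptr && b != nullptr);
    std::int32_t sum = 0;
    for (std::size_t i = 0; i < n; ++i) {
        sum += static_cast<std::int32_t>(a[i]) * static_cast<std::int32_t>(b[i]);
    }
    return sum;
}
```

Because this version relies only on standard integer arithmetic, it compiles for any Android ABI, which is the property the CPU path above depends on.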

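JNI

JNI stands for Java Native Interface. Native code written against the JNI is portable across Java VM implementations on a given platform. JNI has been part of the Java platform since Java 1.1, and it lets Java code interoperate with code written in other languages.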

Building an Android native project with Gradle and CMake (adding C code to an Android app)

[https://developer.android.google.cn/studio/projects/add-native-code?hl=zh-cn]
[https://zhuanlan.zhihu.com/p/496125150]
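The core steps are:

  1. Write the JNI glue code
  2. Create a folder for the C/C++ code
  3. Write the CMakeLists.txt file
  4. Reference CMakeLists.txt from Gradle
  • CMakeLists.txt specifies the paths of the C++ source files to compile and the name of the library to build.
  • In build.gradle, add the native build configuration so that Gradle knows there is a CMakeLists.txt to process.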
android {
    // ...omitted

    // JNI-CPP
    externalNativeBuild {
        cmake {
            path "src/main/jni/CMakeLists.txt" // path to the top-level CMakeLists.txt
            version "3.10.2"
        }
    }
}
  • Once the library is built, load the native library from the Java/Kotlin code
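class MainActivity : AppCompatActivity() {
    // ...omitted

    /* JNI function declaration */
    external fun mnistArduino()

    companion object {
        init {
            // Loads libtop-lib.so, the shared library built by CMake
            System.loadLibrary("top-lib")
        }
    }

    // The call must sit inside a function body; onCreate is used here as an example
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        // ...omitted

        /* call the JNI function */
        mnistArduino()
    }
}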
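  • Map the C++ function to a Java function, so that the rest of the app only uses the mapped Java function. The native function name follows a fixed rule: the prefix Java_, then the package name, class name, and method name of the declaring Java/Kotlin method, joined by underscores. For example:
extern "C" JNIEXPORT void JNICALL
Java_cn_qsbye_android_1cam_MainActivity_mnistArduino(JNIEnv* env, jobject obj){}
  • The _1 in android_1cam is the JNI escape sequence for a literal underscore: the package is cn.qsbye.android_cam, and because a plain _ separates name components, the underscore in android_cam is written as _1.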
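The naming rule can be sketched as a small C++ function that builds the native symbol name from a fully qualified class name and a method name. This is a simplified illustration: the helper names jni_mangle and jni_function_name are hypothetical, and only ASCII letters, digits, '.', and '_' are handled. The full JNI rules also define the escapes _2, _3, and _0xxxx, and a signature suffix for overloaded native methods, none of which are covered here.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical helper: escape one name according to the simplified rule.
//   '.' (package separator) -> "_"
//   '_' (literal underscore) -> "_1"
std::string jni_mangle(const std::string& name) {
    std::string out;
    for (char ch : name) {
        if (ch == '.') {
            out += '_';
        } else if (ch == '_') {
            out += "_1";
        } else {
            // This sketch only supports ASCII letters and digits otherwise.
            assert(std::isalnum(static_cast<unsigned char>(ch)));
            out += ch;
        }
    }
    return out;
}

// Hypothetical helper: "Java_" + mangled class name + "_" + mangled method name.
std::string jni_function_name(const std::string& fq_class, const std::string& method) {
    return "Java_" + jni_mangle(fq_class) + "_" + jni_mangle(method);
}
```

For the class cn.qsbye.android_cam.MainActivity and the method mnistArduino, this yields Java_cn_qsbye_android_1cam_MainActivity_mnistArduino, the symbol name used in this project.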

Implementation

Project layout

android_cam/app/src
.
├── androidTest
│   └── java
│       └── cn
│           └── qsbye
│               └── android_cam
│                   └── ExampleInstrumentedTest.kt
├── main
│   ├── AndroidManifest.xml
│   ├── assets
│   │   ├── yolov8n.bin
│   │   ├── yolov8n.param
│   │   ├── yolov8s.bin
│   │   └── yolov8s.param
│   ├── java
│   │   ├── cn
│   │   │   └── qsbye
│   │   │       └── android_cam
│   │   │           ├── FirstFragment.kt
│   │   │           ├── MainActivity.kt
│   │   │           └── SecondFragment.kt
│   │   └── com
│   │       └── tencent
│   │           └── yolov8ncnn
│   │               ├── MainActivity.java
│   │               └── Yolov8Ncnn.java
│   ├── jni
│   │   ├── CMakeLists.txt
│   │   ├── main.cpp
│   │   └── mnist_arduino
│   │       ├── CMakeLists.txt
│   │       ├── arduino_tinymaix.png
│   │       ├── micros.h
│   │       ├── mnist_arduino.cpp
│   │       ├── mnist_arduino.hpp
│   │       ├── readme.md
│   │       ├── tinymaix.h
│   │       ├── tm_layers.cpp
│   │       ├── tm_model.cpp
│   │       ├── tm_port.h
│   │       └── tm_stat.cpp
│   └── res
│       ├── drawable
│       │   └── ic_launcher_background.xml
│       ├── drawable-v24
│       │   └── ic_launcher_foreground.xml
│       ├── layout
│       │   ├── activity_main.xml
│       │   ├── content_main.xml
│       │   ├── fragment_first.xml
│       │   ├── fragment_second.xml
│       │   └── main.xml
│       ├── menu
│       │   └── menu_main.xml
│       ├── mipmap-anydpi-v26
│       │   ├── ic_launcher.xml
│       │   └── ic_launcher_round.xml
│       ├── mipmap-anydpi-v33
│       │   └── ic_launcher.xml
│       ├── mipmap-hdpi
│       │   ├── ic_launcher.webp
│       │   └── ic_launcher_round.webp
│       ├── mipmap-mdpi
│       │   ├── ic_launcher.webp
│       │   └── ic_launcher_round.webp
│       ├── mipmap-xhdpi
│       │   ├── ic_launcher.webp
│       │   └── ic_launcher_round.webp
│       ├── mipmap-xxhdpi
│       │   ├── ic_launcher.webp
│       │   └── ic_launcher_round.webp
│       ├── mipmap-xxxhdpi
│       │   ├── ic_launcher.webp
│       │   └── ic_launcher_round.webp
│       ├── navigation
│       │   └── nav_graph.xml
│       ├── values
│       │   ├── colors.xml
│       │   ├── dimens.xml
│       │   ├── strings.xml
│       │   └── themes.xml
│       ├── values-land
│       │   └── dimens.xml
│       ├── values-night
│       │   └── themes.xml
│       ├── values-w1240dp
│       │   └── dimens.xml
│       ├── values-w600dp
│       │   └── dimens.xml
│       └── xml
│           ├── backup_rules.xml
│           └── data_extraction_rules.xml
└── test
    └── java
        └── cn
            └── qsbye
                └── android_cam
                    └── ExampleUnitTest.kt

Key code

build.gradle(app)

android {
    // ...omitted

    // JNI-CPP
    externalNativeBuild {
        cmake {
            path "src/main/jni/CMakeLists.txt" // path to the CMakeLists.txt file
            version "3.10.2"
        }
    }
}

jni/CMakeLists.txt

cmake_minimum_required(VERSION 3.10)

project(jni_top)

# Add the subdirectory that builds mnist-lib
ADD_SUBDIRECTORY(mnist_arduino)

find_library( # Sets the name of the path variable.
              log-lib

              # Specifies the name of the NDK library that
              # you want CMake to locate.
              log )

# mnist-lib is a target created by mnist_arduino/CMakeLists.txt through
# ADD_SUBDIRECTORY above, so it is linked by target name below;
# find_library() is only needed for prebuilt libraries such as log.

add_library( # Sets the name of the library.
             top-lib

             # Sets the library as a shared library.
             SHARED

             # Provides a relative path to your source file(s).
             main.cpp)

target_link_libraries( # Specifies the target library.
                       top-lib

                       # Links the target library to the log library
                       # included in the NDK.
                       ${log-lib}
                       mnist-lib )

jni/mnist_arduino/CMakeLists.txt

cmake_minimum_required(VERSION 3.10)

project(mnist_arduino)

find_library( # Sets the name of the path variable.
              log-lib

              # Specifies the name of the NDK library that
              # you want CMake to locate.
              log )

add_library( # Sets the name of the library.
             mnist-lib

             # Sets the library as a STATIC library.
             STATIC

             # Provides a relative path to your source file(s).
             tinymaix.h tm_port.h micros.h tm_layers.cpp tm_model.cpp tm_stat.cpp mnist_arduino.cpp)

# Expose this directory's headers (e.g. mnist_arduino.hpp) to targets that link mnist-lib
target_include_directories(mnist-lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

# Specifies libraries CMake should link to your target library. You
# can link multiple libraries, such as libraries you define in this
# build script, prebuilt third-party libraries, or system libraries.

target_link_libraries( # Specifies the target library.
                       mnist-lib

                       # Links the target library to the log library
                       # included in the NDK.
                       ${log-lib} )

main.cpp

/*
 * Purpose: export C/C++ functions to Kotlin
*/

#include <jni.h>
#include <android/log.h>
#include <exception>
#include "mnist_arduino/mnist_arduino.hpp"

/* begin JNI */
JNIEnv* Serial::env = nullptr;
jobject Serial::obj = nullptr;

extern "C" JNIEXPORT void JNICALL
Java_cn_qsbye_android_1cam_MainActivity_mnistArduino(JNIEnv* env, jobject obj) {
    mnist_print_test();

    try {
        mnist_arduino();
    } catch (const std::exception& e) {
        // standard exceptions
        __android_log_print(ANDROID_LOG_ERROR, "JNI", "mnist_arduino() error: %s", e.what());
    } catch (...) {
        // any other exception
        __android_log_print(ANDROID_LOG_ERROR, "JNI", "mnist_arduino() error: unknown exception");
    }
}
/* end JNI */

micros.h

//
// Created by workSpace on 2023/8/10.
//

#ifndef ANDROID_CAM_MICROS_H
#define ANDROID_CAM_MICROS_H

#include <time.h>

// Microseconds read from a monotonic clock, mirroring Arduino's micros().
// `static inline` avoids multiple-definition link errors when this header
// is included from more than one source file.
static inline long long micros(void){
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    long long microseconds = ts.tv_sec * 1000000LL + ts.tv_nsec / 1000;
    return microseconds;
}

#endif //ANDROID_CAM_MICROS_H
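As a rough sketch of how a microsecond counter like micros() is typically used, the snippet below measures the elapsed time of a piece of work. Note the swap: it uses std::chrono::steady_clock instead of clock_gettime so that it stays portable, and the names micros_now and time_us are hypothetical, not part of the project.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>

// Hypothetical stand-in for micros(): microseconds on a monotonic clock.
std::int64_t micros_now() {
    using namespace std::chrono;
    return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

// Hypothetical helper: run `work` once and return how long it took, in microseconds.
template <typename F>
std::int64_t time_us(F&& work) {
    const std::int64_t start = micros_now();
    work();
    const std::int64_t elapsed = micros_now() - start;
    assert(elapsed >= 0);  // a monotonic clock never goes backwards
    return elapsed;
}
```

Wrapping an inference call in time_us would give a per-run latency figure. In the tm_port.h listing in this post, TM_GET_US() is defined as the constant 1, so the built-in TM_DBGT timing macros do not measure anything as configured.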

tinymaix.h

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef __TINYMAIX_H
#define __TINYMAIX_H

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>


#define  TM_MDL_INT8    0
#define  TM_MDL_INT16   1
#define  TM_MDL_FP32    2
#include "tm_port.h"

/******************************* MACRO ************************************/
#define TM_MDL_MAGIC 'XIAM'     //mdl magic sign
#define TM_ALIGN_SIZE   (8)     //8 byte align
#define TM_MATP(mat,y,x,ch) ((mat)->data + ((y)*(mat)->w + (x))*(mat)->c + (ch))
                                //HWC
#if   TM_MDL_TYPE == TM_MDL_INT8
    typedef int8_t  mtype_t;    //mat data type
    typedef int8_t  wtype_t;    //weight data type
    typedef int32_t btype_t;    //bias data type
    typedef int32_t sumtype_t;  //sum data type 
    typedef int32_t zptype_t;   //zeropoint data type 
    #define UINT2INT_SHIFT (0)
#elif TM_MDL_TYPE == TM_MDL_INT16
    typedef int16_t mtype_t;    //mat data type
    typedef int16_t wtype_t;    //weight data type
    typedef int32_t btype_t;    //bias data type
    typedef int32_t sumtype_t;  //sum data type 
    typedef int32_t zptype_t;   //zeropoint data type
    #define UINT2INT_SHIFT (8)
#elif TM_MDL_TYPE == TM_MDL_FP32
    typedef float   mtype_t;    //mat data type
    typedef float   wtype_t;    //weight data type
    typedef float   btype_t;    //bias data type
    typedef float   sumtype_t;  //sum data type 
    typedef float   zptype_t;   //zeropoint data type 
#else 
    #error "Not support this MDL_TYPE!"
#endif

typedef float sctype_t;
#define TM_FASTSCALE_SHIFT (8)

/******************************* ENUM ************************************/
typedef enum{
    TM_OK = 0,
    TM_ERR= 1,
    TM_ERR_MAGIC     = 2,
    TM_ERR_UNSUPPORT = 3,
    TM_ERR_OOM       = 4,
    TM_ERR_LAYERTYPE = 5,
    TM_ERR_DIMS      = 6,
    TM_ERR_TODO      = 7,
    TM_ERR_MDLTYPE   = 8,
    TM_ERR_KSIZE     = 9,
}tm_err_t;

typedef enum{
    TML_CONV2D    = 0,
    TML_GAP       = 1,
    TML_FC        = 2,
    TML_SOFTMAX   = 3,
    TML_RESHAPE   = 4,
    TML_DWCONV2D  = 5,
    TML_MAXCNT    ,
}tm_layer_type_t;

typedef enum{
    TM_PAD_VALID  = 0,
    TM_PAD_SAME   = 1,
}tm_pad_type_t;

typedef enum{
    TM_ACT_NONE   = 0,
    TM_ACT_RELU   = 1,
    TM_ACT_RELU1  = 2,
    TM_ACT_RELU6  = 3,
    TM_ACT_TANH   = 4,
    TM_ACT_SIGNBIT= 5,
    TM_ACT_MAXCNT ,
}tm_act_type_t;


typedef enum {
    TMPP_NONE      = 0,
    TMPP_FP2INT    = 1,  //user own fp buf -> int input buf
    TMPP_UINT2INT  = 2,  //int8: cvt in place; int16: can't cvt in place
    TMPP_UINT2FP01 = 3,  // u8/255.0
    TMPP_UINT2FPN11= 4,  // (u8-128)/128  
    TMPP_MAXCNT,
}tm_pp_t;

/******************************* STRUCT ************************************/
//mdlbin in flash
typedef struct{
    uint32_t magic;         //"MAIX"
    uint8_t  mdl_type;      //0 int8, 1 int16, 2 fp32,
    uint8_t  out_deq;       //0 don't dequant out; 1 dequant out
    uint16_t input_cnt;     //only support 1 yet
    uint16_t output_cnt;    //only support 1 yet
    uint16_t layer_cnt;     
    uint32_t buf_size;      //main buf size for middle result
    uint32_t sub_size;      //sub buf size for middle result
    uint16_t in_dims[4];    //0:dims; 1:dim0; 2:dim1; 3:dim2
    uint16_t out_dims[4];
    uint8_t  reserve[28];   //reserve for future
    uint8_t  layers_body[0];//oft 64 here
}tm_mdlbin_t;

//mdl meta data in ram
typedef struct{
    tm_mdlbin_t* b;         //bin
    void*    cb;            //Layer callback
    uint8_t* buf;           //main buf addr
    uint8_t* subbuf;        //sub buf addr
    uint16_t main_alloc;    //is main buf alloc or static
    uint16_t layer_i;       //current layer index
    uint8_t* layer_body;    //current layer body addr
}tm_mdl_t;

//dims==3, hwc
//dims==2, 1wc
//dims==1, 11c
typedef struct{
    uint16_t dims;
    uint16_t h;
    uint16_t w;
    uint16_t c;
    union {
        mtype_t* data;
        float*   dataf;
    };
}tm_mat_t;

/******************************* LAYER STRUCT ************************************/
typedef struct{             //48byte
    uint16_t type;          //layer type
    uint16_t is_out;        //is output
    uint32_t size;          //8 byte align size for this layer
    uint32_t in_oft;        //input  oft in main buf
    uint32_t out_oft;       //output oft in main buf
    uint16_t in_dims[4];    //0:dims; 1:dim0; 2:dim1; 3:dim2
    uint16_t out_dims[4];
                            //following unit not used in fp32 mode
    sctype_t in_s;          //input scale, 
    zptype_t in_zp;         //input zeropoint
    sctype_t out_s;         //output scale
    zptype_t out_zp;        //output zeropoint
    //note: real = scale*(q-zeropoint)
}tml_head_t;

typedef struct{
    tml_head_t h;

    uint8_t  kernel_w;
    uint8_t  kernel_h;
    uint8_t  stride_w;
    uint8_t  stride_h;
    
    uint8_t  dilation_w;
    uint8_t  dilation_h;
    uint16_t  act;          //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
    
    uint8_t  pad[4];        //top,bottom,left,right

    uint32_t depth_mul;     //depth_multiplier: if conv2d,=0; else: >=1
    uint32_t reserve;       //for 8byte align
    
    uint32_t ws_oft;        //weight scale oft from this layer start 
                            //skip bias scale: bias_scale = weight_scale*in_scale
    uint32_t w_oft;         //weight oft from this layer start
    uint32_t b_oft;         //bias oft from this layer start 
    //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
    //      fused in advance (when convert model)
}tml_conv2d_dw_t;  //compatible with conv2d and dwconv2d

typedef struct{
    tml_head_t h;
}tml_gap_t;

typedef struct{
    tml_head_t h;

    uint32_t ws_oft;        //weight scale oft from this layer start 
    uint32_t w_oft;         //weight oft from this layer start
    uint32_t b_oft;         //bias oft from this layer start
    uint32_t reserve;       //for 8byte align
}tml_fc_t;

typedef struct{
    tml_head_t h;
}tml_softmax_t;

typedef struct{
    tml_head_t h;
}tml_reshape_t;

typedef struct{
    tml_head_t h;

    uint8_t  kernel_w;
    uint8_t  kernel_h;
    uint8_t  stride_w;
    uint8_t  stride_h;
    
    uint8_t  dilation_w;
    uint8_t  dilation_h;
    uint16_t  act;          //0 none, 1 relu, 2 relu1, 3 relu6, 4 tanh, 5 sign_bit
    
    uint8_t  pad[4];        //top,bottom,left,right


    
    uint32_t ws_oft;        //weight scale oft from this layer start 
                            //skip bias scale: bias_scale = weight_scale*in_scale
    uint32_t w_oft;         //weight oft from this layer start
    uint32_t b_oft;         //bias oft from this layer start 
    //note: bias[c] = bias[c] + (-out_zp)*sum(w[c*chi*maxk:(c+1)*chi*maxk])
    //      fused in advance (when convert model)
}tml_dwconv2d_t;



/******************************* TYPE ************************************/
typedef tm_err_t (*tml_stat_t)(tml_head_t* layer, tm_mat_t* in, tm_mat_t* out);
typedef tm_err_t (*tm_cb_t)(tm_mdl_t* mdl, tml_head_t* lh);


/******************************* GLOBAL VARIABLE ************************************/


/******************************* MODEL FUNCTION ************************************/
tm_err_t tm_load  (tm_mdl_t* mdl, const uint8_t* bin, uint8_t*buf, tm_cb_t cb, tm_mat_t* in);   //load model
void     tm_unload(tm_mdl_t* mdl);                                      //remove model
tm_err_t tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out);            //preprocess input data
tm_err_t tm_run   (tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out);         //run model


/******************************* LAYER FUNCTION ************************************/
tm_err_t tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
    int kw, int kh, int sx, int sy, int dx, int dy, int act, \
    int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
tm_err_t tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
tm_err_t tml_fc(tm_mat_t* in, tm_mat_t* out,  wtype_t* w, btype_t* b, \
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
tm_err_t tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);
tm_err_t tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp);


/******************************* STAT FUNCTION ************************************/
//#define  TM_ENABLE_STAT
#if TM_ENABLE_STAT
tm_err_t tm_stat(tm_mdlbin_t* mdl);                    //stat model
#endif


/******************************* UTILS  ************************************/
#define TML_GET_INPUT(mdl,lh)   ((mtype_t*)((mdl)->buf + (lh)->in_oft))
#define TML_GET_OUTPUT(mdl,lh)  ((mtype_t*)((mdl)->buf + (lh)->out_oft))
#define TML_DEQUANT(lh, x)       (((sumtype_t)(x)-((lh)->out_zp))*((lh)->out_s))

#endif 
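To make the tm_mdlbin_t layout above concrete, the following sketch decodes the fixed header fields from the first 36 bytes of the mdl_data array listed in mnist_arduino.cpp. It is an illustration under stated assumptions, not TinyMaix's loader: the struct MdlHeader and the helpers rd16, rd32, and parse_header are hypothetical, and the fields are assumed to be stored little-endian at the offsets implied by the struct with no padding.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the leading tm_mdlbin_t fields (not the library type).
struct MdlHeader {
    std::uint32_t magic;
    std::uint8_t  mdl_type, out_deq;
    std::uint16_t input_cnt, output_cnt, layer_cnt;
    std::uint32_t buf_size, sub_size;
    std::uint16_t in_dims[4], out_dims[4];
};

// Read little-endian integers byte by byte (independent of host endianness).
static std::uint16_t rd16(const std::uint8_t* p) {
    return static_cast<std::uint16_t>(p[0] | (p[1] << 8));
}
static std::uint32_t rd32(const std::uint8_t* p) {
    return static_cast<std::uint32_t>(p[0]) | (static_cast<std::uint32_t>(p[1]) << 8) |
           (static_cast<std::uint32_t>(p[2]) << 16) | (static_cast<std::uint32_t>(p[3]) << 24);
}

MdlHeader parse_header(const std::uint8_t* bin, std::size_t len) {
    assert(bin != nullptr && len >= 36);  // 36 bytes cover magic..out_dims
    MdlHeader h{};
    h.magic      = rd32(bin);
    h.mdl_type   = bin[4];
    h.out_deq    = bin[5];
    h.input_cnt  = rd16(bin + 6);
    h.output_cnt = rd16(bin + 8);
    h.layer_cnt  = rd16(bin + 10);
    h.buf_size   = rd32(bin + 12);
    h.sub_size   = rd32(bin + 16);
    for (int i = 0; i < 4; ++i) {
        h.in_dims[i]  = rd16(bin + 20 + 2 * i);
        h.out_dims[i] = rd16(bin + 28 + 2 * i);
    }
    return h;
}

// First 36 bytes of mdl_data, copied from the mnist_arduino.cpp listing.
const std::uint8_t kHeaderBytes[36] = {
    0x4d, 0x41, 0x49, 0x58, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x06, 0x00,
    0xc0, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1c, 0x00,
    0x1c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x0a, 0x00,
};
```

Decoded this way, the header describes an int8 model (mdl_type 0) with 6 layers, a 960-byte main buffer (matching MDL_BUF_LEN), a 28x28x1 input, and 10 output values.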

tm_port.h

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef __TM_PORT_H
#define __TM_PORT_H

#include <time.h>
#include <iostream>
//#include <micros.h>

//#include <Arduino.h>

#define TM_ARCH_OPT0        (0) //default
#define TM_ARCH_OPT1        (1)
#define TM_ARCH_OPT2        (2)
#define TM_ARCH_ARM_SIMD    (3) //seems 32bit simd not faster enough
#define TM_ARCH_ARM_NEON    (4) //TODO
#define TM_ARCH_RV32P       (5) //
#define TM_ARCH_RV64V       (6)

/******************************* PORT CONFIG  ************************************/
#define TM_ARCH         TM_ARCH_ARM_SIMD
#define TM_MDL_TYPE     TM_MDL_INT8 
#define TM_FASTSCALE    (1)         //enable if your chip don't have FPU, may speed up 1/3, but decrease accuracy
#define TM_ENABLE_STAT  (0)         //enable mdl stat functions
#define TM_MAX_KSIZE    (5*5)       //max kernel_size

#define tm_malloc(x)   malloc(x)
#define tm_free(x)     free(x)

#define TM_PRINTF(...) printf(__VA_ARGS__);
#define TM_DBG(...)    TM_PRINTF("###L%d: ",__LINE__);TM_PRINTF(__VA_ARGS__);
#define TM_DBGL()      TM_PRINTF("###L%d\n",__LINE__);

#define  TM_GET_US()       1
//#define TM_GET_US() micros()

#define TM_DBGT_INIT()     uint32_t _start,_finish;uint32_t _time;_start=TM_GET_US();
#define TM_DBGT_START()    _start=TM_GET_US();
#define TM_DBGT(x)         {_finish=TM_GET_US();\
                            _time = (_finish-_start);\
                            _start=TM_GET_US();\
}

/******************************* OPS CONFIG  ************************************/

#endif
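With TM_MDL_TYPE set to TM_MDL_INT8, activations are stored as 8-bit integers and mapped back to real values by the rule noted in tinymaix.h, real = scale*(q-zeropoint). The sketch below works through that formula. The function names dequantize and quantize are hypothetical, and the inverse mapping (round, then clamp) is an assumption added for illustration; the source only states the dequantization direction. The example values, scale 1/256 and zero point -128, appear to match the output scale and zero point stored in the last eight bytes of mdl_data.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// real = scale * (q - zero_point), as noted in tinymaix.h.
float dequantize(std::int8_t q, float scale, std::int32_t zero_point) {
    return scale * static_cast<float>(static_cast<std::int32_t>(q) - zero_point);
}

// Assumed inverse for illustration: round to nearest, then clamp to the int8 range.
std::int8_t quantize(float real, float scale, std::int32_t zero_point) {
    assert(scale > 0.0f);
    long q = std::lround(real / scale) + zero_point;
    if (q < -128) q = -128;
    if (q > 127) q = 127;
    return static_cast<std::int8_t>(q);
}
```

With scale 1/256 and zero point -128, the int8 range [-128, 127] covers real values from 0 to 255/256, which suits softmax probabilities.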

mnist_arduino.hpp

#include <jni.h>
#include <android/log.h>
#include <iostream>
#include <sstream>
#include <string>

#ifndef __MNIST_ARDUINO_H
#define __MNIST_ARDUINO_H

class Printable {
public:
    virtual ~Printable() = default;
    virtual std::string toString() const = 0;
};

template<typename T>
class PrintableImpl : public Printable {
private:
    const T& value;

public:
    PrintableImpl(const T& val) : value(val) {}

    std::string toString() const override {
        std::ostringstream oss;
        // Note: any value that compares equal to 0, including the integer 0,
        // is rendered as "NULL".
        if (value == 0) {
            oss << "NULL";
        } else {
            oss << value;
        }
        return oss.str();
    }
};

class Serial {
public:
    static JNIEnv* env;
    static jobject obj;

    template<typename T>
    static void print(const T& printable) {
        PrintableImpl<T> p(printable);
        std::string content = p.toString();
        const char* str = content.c_str();
        //jstring jstr = env->NewStringUTF(str);
        //const char* utfStr = env->GetStringUTFChars(jstr, nullptr);
        //__android_log_print(ANDROID_LOG_ERROR, "CPP", "%s", utfStr);
        __android_log_print(ANDROID_LOG_ERROR, "CPP", "%s", str);
        //env->ReleaseStringUTFChars(jstr, utfStr);
        //env->DeleteLocalRef(jstr);
    }
};

void mnist_arduino(void);
void mnist_print_test(void);
#endif

mnist_arduino.cpp

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mnist_arduino.hpp"

#include "stdio.h"
#include "tinymaix.h"

#include <jni.h>
#include <android/log.h>
#include <iostream>
#include <sstream>

#include <stdint.h>
#define MDL_BUF_LEN (960)
#define LBUF_LEN (360)
const uint8_t mdl_data[920] ={\
  0x4d, 0x41, 0x49, 0x58, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x06, 0x00, 0xc0, 0x03, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1c, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 
  0x01, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x03, 0x00, 0x00, 
  0x03, 0x00, 0x1c, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x03, 0x00, 0x0d, 0x00, 0x0d, 0x00, 0x01, 0x00, 
  0x81, 0x80, 0x80, 0x3b, 0x80, 0xff, 0xff, 0xff, 0x68, 0x9d, 0x6b, 0x3c, 0x80, 0xff, 0xff, 0xff, 
  0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 
  0xb8, 0xfc, 0x51, 0x3c, 0x00, 0x00, 0x00, 0x00, 0x50, 0x7f, 0x30, 0xed, 0x10, 0x17, 0x9b, 0xc8, 
  0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc4, 0x3d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00, 0x10, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x03, 0x00, 0x0d, 0x00, 0x0d, 0x00, 0x01, 0x00, 0x03, 0x00, 0x06, 0x00, 0x06, 0x00, 0x03, 0x00, 
  0x68, 0x9d, 0x6b, 0x3c, 0x80, 0xff, 0xff, 0xff, 0x40, 0x47, 0x57, 0x3c, 0x80, 0xff, 0xff, 0xff, 
  0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 
  0xa6, 0x15, 0xf4, 0x3b, 0x51, 0x1d, 0x45, 0x3b, 0x98, 0x06, 0xcf, 0x3b, 0x00, 0x00, 0x00, 0x00, 
  0xea, 0x00, 0x03, 0x9a, 0x81, 0xd0, 0xfb, 0x09, 0x09, 0x13, 0x47, 0x5a, 0xda, 0x4c, 0x7f, 0xf9, 
  0x44, 0x56, 0x23, 0x16, 0x11, 0x24, 0x10, 0x09, 0x81, 0xc5, 0xd7, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x08, 0x90, 0xff, 0xff, 0xc5, 0xfe, 0x00, 0x00, 0x92, 0xe3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xa8, 0x03, 0x00, 0x00, 
  0x03, 0x00, 0x06, 0x00, 0x06, 0x00, 0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x02, 0x00, 0x06, 0x00, 
  0x40, 0x47, 0x57, 0x3c, 0x80, 0xff, 0xff, 0xff, 0xfe, 0xcd, 0x65, 0x3d, 0x80, 0xff, 0xff, 0xff, 
  0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x10, 0x01, 0x00, 0x00, 
  0x18, 0xa6, 0x35, 0x3d, 0x64, 0xd6, 0xfd, 0x3c, 0x16, 0x84, 0x37, 0x3d, 0x92, 0x5a, 0x07, 0x3d, 
  0xb7, 0xfa, 0xb2, 0x3c, 0x7e, 0xd4, 0x13, 0x3d, 0xb0, 0x2a, 0xc4, 0x26, 0x2d, 0x00, 0xad, 0x81, 
  0xe9, 0xfc, 0x0d, 0xf3, 0x10, 0x08, 0xec, 0xfa, 0x04, 0x06, 0x0d, 0x2f, 0xde, 0x21, 0x0c, 0xef, 
  0x1e, 0xe7, 0x06, 0x81, 0xcd, 0xc2, 0xef, 0xf8, 0x07, 0x04, 0xd2, 0xe2, 0x0f, 0x39, 0x40, 0xed, 
  0x05, 0x04, 0xfa, 0xd2, 0xd1, 0xad, 0xdb, 0x15, 0x11, 0xdf, 0xf0, 0x0f, 0xf4, 0x0f, 0xf8, 0x14, 
  0xde, 0x0e, 0x08, 0x17, 0x7f, 0x3a, 0x2a, 0x13, 0x0c, 0x02, 0xfd, 0x28, 0x0f, 0x21, 0xec, 0xcd, 
  0xd3, 0x0a, 0x25, 0x33, 0x12, 0xf1, 0x45, 0x18, 0x1a, 0xc4, 0xc1, 0xd9, 0xe8, 0x07, 0x01, 0x81, 
  0x00, 0x0c, 0xde, 0xb6, 0x04, 0xd4, 0x12, 0x25, 0xf0, 0x43, 0x05, 0xd4, 0xd4, 0x09, 0xa7, 0x30, 
  0x36, 0xb1, 0xef, 0x3b, 0x2b, 0xcf, 0x81, 0x14, 0x0d, 0xe9, 0xbc, 0xf9, 0x03, 0x29, 0x5c, 0x57, 
  0xb6, 0xa6, 0xd0, 0xff, 0x22, 0x02, 0xd8, 0x04, 0x16, 0xff, 0x08, 0xf1, 0xb4, 0xb9, 0x0a, 0x00, 
  0x14, 0x3d, 0xcc, 0xcc, 0xde, 0xd3, 0xca, 0xdd, 0x41, 0x35, 0xf1, 0x9a, 0xa3, 0xc3, 0xe2, 0x09, 
  0x27, 0xd1, 0xdd, 0x11, 0xe4, 0x81, 0xd5, 0xd6, 0xb5, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x25, 0xc3, 0xff, 0xff, 0x03, 0x81, 0xff, 0xff, 0x00, 0xb7, 0x00, 0x00, 0x46, 0x73, 0xff, 0xff, 
  0xff, 0x9b, 0xff, 0xff, 0x2d, 0x03, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 
  0xa8, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x02, 0x00, 0x02, 0x00, 0x06, 0x00, 
  0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x06, 0x00, 0xfe, 0xcd, 0x65, 0x3d, 0x80, 0xff, 0xff, 0xff, 
  0x11, 0xef, 0xc5, 0x3c, 0x80, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0xb0, 0x03, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x06, 0x00, 
  0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x0a, 0x00, 0x11, 0xef, 0xc5, 0x3c, 0x80, 0xff, 0xff, 0xff, 
  0x4f, 0x75, 0x1d, 0x3e, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 
  0xa8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x85, 0xba, 0x05, 0x3d, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
  0xb9, 0x0f, 0xcc, 0x08, 0x38, 0xf3, 0xc8, 0x1d, 0xd8, 0xa0, 0xd9, 0x44, 0x13, 0x9b, 0x17, 0x04, 
  0x2f, 0xe9, 0xfa, 0xea, 0x27, 0xf6, 0x28, 0x90, 0x28, 0xe6, 0x09, 0xee, 0x81, 0x31, 0xe6, 0xf8, 
  0x1b, 0x19, 0x0c, 0xda, 0xd2, 0xc7, 0xb3, 0x2f, 0xc9, 0x23, 0x21, 0xfd, 0x25, 0xae, 0x19, 0xef, 
  0xf3, 0x3a, 0xcf, 0x08, 0xe0, 0xc5, 0x29, 0x26, 0x0c, 0xd9, 0xb8, 0xdb, 0x00, 0x00, 0x00, 0x00, 
  0x8c, 0xdf, 0xff, 0xff, 0x0d, 0xc7, 0xff, 0xff, 0xb3, 0xf2, 0xff, 0xff, 0x40, 0xe2, 0xff, 0xff, 
  0x84, 0xd4, 0xff, 0xff, 0xac, 0xf3, 0xff, 0xff, 0x86, 0xb5, 0xff, 0xff, 0x19, 0xf6, 0xff, 0xff, 
  0x75, 0xdc, 0xff, 0xff, 0xfa, 0xe3, 0xff, 0xff, 0x03, 0x00, 0x01, 0x00, 0x30, 0x00, 0x00, 0x00, 
  0xb0, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x0a, 0x00, 
  0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x0a, 0x00, 0x4f, 0x75, 0x1d, 0x3e, 0x3f, 0x00, 0x00, 0x00, 
  0x00, 0x00, 0x80, 0x3b, 0x80, 0xff, 0xff, 0xff, 
};


const uint8_t mnist_pic[28*28] ={
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,116,125,171,255,255,150, 93,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,169,253,253,253,253,253,253,218, 30,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,169,253,253,253,213,142,176,253,253,122,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0, 52,250,253,210, 32, 12,  0,  6,206,253,140,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0, 77,251,210, 25,  0,  0,  0,122,248,253, 65,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0, 31, 18,  0,  0,  0,  0,209,253,253, 65,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,117,247,253,198, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 76,247,253,231, 63,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,128,253,253,144,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,176,246,253,159, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 25,234,253,233, 35,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,198,253,253,141,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0, 78,248,253,189, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0, 19,200,253,253,141,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,134,253,253,173, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,248,253,253, 25,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,248,253,253, 43, 20, 20, 20, 20,  5,  0,  5, 20, 20, 37,150,150,150,147, 10,  0,
  0,  0,  0,  0,  0,  0,  0,  0,248,253,253,253,253,253,253,253,168,143,166,253,253,253,253,253,253,253,123,  0,
  0,  0,  0,  0,  0,  0,  0,  0,174,253,253,253,253,253,253,253,253,253,253,253,249,247,247,169,117,117, 57,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,118,123,123,123,166,253,253,253,155,123,123, 41,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
};

tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
{   //dump middle result
    //return TM_OK;
    int h = lh->out_dims[1];
    int w = lh->out_dims[2];
    int ch= lh->out_dims[3];
    mtype_t* output = TML_GET_OUTPUT(mdl, lh);
    Serial::print("Layer ");Serial::print(mdl->layer_i);Serial::print(" callback ========\n");
    for(int y=0; y<h; y++){
        Serial::print("[");
        for(int x=0; x<w; x++){
            Serial::print("[");
            for(int c=0; c<ch; c++){
                Serial::print(output[(y*w+x)*ch+c]);Serial::print(",");
            }
            Serial::print("],");
        }
        Serial::print("],\n");
    }
    Serial::print("\n");
    return TM_OK;
}

static void parse_output(tm_mat_t* outs)
{
    tm_mat_t out = outs[0];
    float* data  = out.dataf;
    float maxp = 0;
    int maxi = -1;
    for(int i=0; i<10; i++){
        TM_PRINTF("%d: %.3f\n", i, data[i]);
        Serial::print(i);Serial::print(": ");Serial::print(int(data[i]*100));Serial::print("\n");
        if(data[i] > maxp) {
            maxi = i;
            maxp = data[i];
        }
    }
    Serial::print("### Predict output is: Number ");Serial::print(maxi);Serial::print(", prob=");Serial::print(int(maxp*100));
    Serial::print("\n");
    return;
}


#define IMG_L   (28)
#define IMG_CH  (1)
#define CLASS_N (10)

static uint8_t mdl_buf[MDL_BUF_LEN];

void mnist_arduino(void) {
    //Serial.begin(115200);
    //printf_begin();
    TM_DBGT_INIT();
    Serial::print("mnist_arduino()\n");
    tm_mdl_t mdl;

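    //Optional debug dump of the input image. With the pixel print left
    //disabled, this loop only emits one newline per image row.
    for(int i=0; i<IMG_L*IMG_L; i++){
        //Serial::print(mnist_pic[i]>>4);
        if(i%IMG_L==IMG_L-1) Serial::print("\n");
    }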

    tm_mat_t in_uint8 = {3,IMG_L,IMG_L,IMG_CH, (mtype_t*)mnist_pic};
    tm_mat_t in = {3,IMG_L,IMG_L,IMG_CH, NULL};
    tm_mat_t outs[1];
    tm_err_t res;
    //tm_stat((tm_mdlbin_t*)mdl_data);

    res = tm_load(&mdl, mdl_data, mdl_buf, layer_cb, &in);
    if(res != TM_OK) {
        TM_PRINTF("tm model load err %d\n", res);
        return ;
    }

    res = tm_preprocess(&mdl, TMPP_UINT2INT, &in_uint8, &in);
    if(res != TM_OK) {
        TM_PRINTF("tm preprocess err %d\n", res);
        tm_unload(&mdl);
        return ;
    }

    TM_DBGT_START();
    res = tm_run(&mdl, &in, outs);
    TM_DBGT("tm_run");
    if(res==TM_OK) parse_output(outs);  
    else TM_PRINTF("tm run error: %d\n", res);
    tm_unload(&mdl);                  
    return ;
}

void mnist_print_test(void){
     __android_log_print(ANDROID_LOG_ERROR, "JNI", "mnist_print_test()");
     Serial::print("mnist demo\n");

}
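
The commented-out `>>4` print in `mnist_arduino()` suggests the input image was meant to be dumped to the log before inference. A minimal sketch of such a dump follows, assuming a row-major 8-bit grayscale buffer. `render_ascii` and its shade table are hypothetical helpers, not part of TinyMaix or of the project source.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>

// Hypothetical helper (not part of TinyMaix): renders a row-major 8-bit
// grayscale image as text, one character per pixel and one line per row.
std::string render_ascii(const uint8_t* pixels, int width, int height) {
    static const char shades[] = " .:-=+*#%@";   // 10 levels, dark to bright
    std::string text;
    text.reserve(static_cast<size_t>((width + 1) * height));
    for (int y = 0; y < height; y++) {
        for (int x = 0; x < width; x++) {
            int v = pixels[y * width + x];       // pixel value, 0..255
            text += shades[v * 10 / 256];        // map to shade index 0..9
        }
        text += '\n';                            // end of image row
    }
    return text;
}
```

For the 28x28 sample above the call would be `render_ascii(mnist_pic, 28, 28)`. Building the whole picture as one string keeps the dump separate from whatever logging call is used to emit it.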

tm_stat.cpp

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tinymaix.h"

#if TM_ENABLE_STAT


static const char* tml_str_tbl[TML_MAXCNT] = {
    "Conv2D",   /*TML_CONV2D  = 0,*/
    "GAP",      /*TML_GAP     = 1,*/
    "FC",       /*TML_FC      = 2,*/
    "Softmax",  /*TML_SOFTMAX = 3,*/
    "Reshape",  /*TML_RESHAPE = 4,*/
    "DWConv2D", /*TML_DWCONV2D= 5,*/
};

static const int tml_headsize_tbl[TML_MAXCNT] = {
    sizeof(tml_conv2d_dw_t),
    sizeof(tml_gap_t),
    sizeof(tml_fc_t),
    sizeof(tml_softmax_t),
    sizeof(tml_reshape_t),
    sizeof(tml_conv2d_dw_t),
};

tm_err_t tm_stat(tm_mdlbin_t* b)
{   
    printf("================================ model stat ================================\n");
    printf("mdl_type=%d (0 int8, 1 int16, 2 fp32)\n", b->mdl_type);
    printf("out_deq=%d \n", b->out_deq);
    printf("input_cnt=%d, output_cnt=%d, layer_cnt=%d\n", b->input_cnt, b->output_cnt, b->layer_cnt);
    uint16_t* idim = b->in_dims;
    printf("input %ddims: (%d, %d, %d)\n", idim[0],idim[1],idim[2],idim[3]);
    uint16_t* odim = b->out_dims;
    printf("output %ddims: (%d, %d, %d)\n", odim[0],odim[1],odim[2],odim[3]);
    //printf("model param bin addr: 0x%x\n", (uint32_t)(b->layers_body));
    printf("main buf size %d; sub buf size %d\n", \
        b->buf_size,b->sub_size);

    printf("//Note: PARAM is layer param size, including alignment padding\r\n\r\n");
    printf("Idx\tLayer\t         outshape\tinoft\toutoft\tPARAM\tMEMOUT\tOPS\n");
    printf("---\tInput    \t%3d,%3d,%3d\t-   \t0    \t0 \t%ld \t0\n",\
        idim[1],idim[2],idim[3], (long)(idim[1]*idim[2]*idim[3]*sizeof(mtype_t)));   //cast: size_t is not long on every ABI
    //      000  Input    -     224,224,3  0x40001234 0x40004000 100000 500000 200000
    //printf("000  Input    -     %3d,%3d,%d  0x%08x   0x%08x     %6d %6d %6d\n",) 
    int sum_param = 0;
    int sum_ops   = 0;
    uint8_t*layer_body  = (uint8_t*)b->layers_body;
    int layer_i;
    for(layer_i = 0; layer_i < b->layer_cnt; layer_i++){
        tml_head_t* h = (tml_head_t*)(layer_body);
        TM_DBG("body oft = %d\n", (int)((uint8_t*)h - (uint8_t*)b));   //pointer difference, safe on 64-bit targets
        TM_DBG("type=%d, is_out=%d, size=%d, in_oft=%d, out_oft=%d, in_dims=[%d,%d,%d,%d], out_dims=[%d,%d,%d,%d], in_s=%.3f, in_zp=%d, out_s=%.3f, out_zp=%d\n",\
                h->type,h->is_out,h->size,h->in_oft,h->out_oft,\
                h->in_dims[0],h->in_dims[1],h->in_dims[2],h->in_dims[3],\
                h->out_dims[0],h->out_dims[1],h->out_dims[2],h->out_dims[3],\
                h->in_s,(int)(h->in_zp),h->out_s,(int)(h->out_zp));
        if(h->type < TML_MAXCNT) {
            int memout = h->out_dims[1]*h->out_dims[2]*h->out_dims[3];
            sum_param += (h->size - tml_headsize_tbl[h->type]);
            int ops = 0;
            switch(h->type){
            case TML_CONV2D: {
                tml_conv2d_dw_t* l = (tml_conv2d_dw_t*)(layer_body);
                ops = memout*(l->kernel_w)*(l->kernel_h)*(h->in_dims[3]);   //MAC as ops
                TM_DBG("Conv2d: kw=%d, kh=%d, sw=%d, sh=%d, dw=%d, dh=%d, act=%d, pad=[%d,%d,%d,%d], dmul=%d, ws_oft=%d, w_oft=%d, b_oft=%d\n",\
                    l->kernel_w, l->kernel_h, l->stride_w, l->stride_h, l->dilation_w, l->dilation_h, \
                    l->act, l->pad[0], l->pad[1], l->pad[2], l->pad[3], l->depth_mul, \
                    l->ws_oft, l->w_oft, l->b_oft);
                break;}
            case TML_GAP:
                ops = (h->in_dims[1])*(h->in_dims[2])*(h->in_dims[3]);  //SUM as ops
                break;
            case TML_FC: {
                tml_fc_t* l = (tml_fc_t*)(layer_body);
                ops = (h->out_dims[3])*(h->in_dims[3]);         //MAC as ops
                TM_DBG("FC: ws_oft=%d, w_oft=%d, b_oft=%d\n",\
                    l->ws_oft, l->w_oft, l->b_oft);
                break;}
            case TML_SOFTMAX:
                ops = 6*(h->out_dims[3]);                       //mixed
                break;
            case TML_DWCONV2D: {
                tml_conv2d_dw_t* l = (tml_conv2d_dw_t*)(layer_body);
                ops = memout*(l->kernel_w)*(l->kernel_h)*1;   //MAC as ops
                TM_DBG("DWConv2d: kw=%d, kh=%d, sw=%d, sh=%d, dw=%d, dh=%d, act=%d, pad=[%d,%d,%d,%d], dmul=%d, ws_oft=%d, w_oft=%d, b_oft=%d\n",\
                    l->kernel_w, l->kernel_h, l->stride_w, l->stride_h, l->dilation_w, l->dilation_h, \
                    l->act, l->pad[0], l->pad[1], l->pad[2], l->pad[3], l->depth_mul,\
                    l->ws_oft, l->w_oft, l->b_oft);
                break;}
            default:
                ops = 0;
                break;
            }
            sum_ops += ops;
            printf("%03d\t%s      \t%3d,%3d,%3d\t%d\t%d\t%d\t%ld\t", layer_i, tml_str_tbl[h->type], \
                h->out_dims[1], h->out_dims[2], h->out_dims[3], \
                h->in_oft, h->out_oft, (int)(h->size - tml_headsize_tbl[h->type]), \
                (long)(memout*sizeof(mtype_t)));   //cast: size_t is not long on every ABI
            printf("%d\r\n", ops);
        } else {
            return TM_ERR_LAYERTYPE;
        }
        layer_body += (h->size);
    }
    printf("\r\nTotal param ~%.1f KB, OPS ~%.2f MOPS, buffer %.1f KB\r\n\r\n", \
        sum_param/1024.0, sum_ops/1000000.0, (b->buf_size + b->sub_size)/1024.0);
    return TM_OK;
} 


#endif
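
The OPS column printed by `tm_stat()` follows a few simple formulas: a standard convolution costs `kw*kh*in_ch` multiply-accumulates per output element, a depthwise convolution costs `kw*kh`, and a fully connected layer costs `out_ch*in_ch`. The sketch below restates those formulas as standalone functions. The function names are illustrative, and 64-bit counters are an assumption made here to avoid overflow on larger models (the listing above uses `int`).

```cpp
#include <cstdint>

// Number of output elements of one layer: height * width * channels.
int64_t out_elems(int h, int w, int c) {
    return static_cast<int64_t>(h) * w * c;
}

// Standard convolution: every output element sums over kw*kh*in_c products.
int64_t conv2d_macs(int out_h, int out_w, int out_c, int kw, int kh, int in_c) {
    return out_elems(out_h, out_w, out_c) * kw * kh * in_c;
}

// Depthwise convolution: every output element sees a single input channel.
int64_t dwconv2d_macs(int out_h, int out_w, int out_c, int kw, int kh) {
    return out_elems(out_h, out_w, out_c) * kw * kh;
}

// Fully connected layer: one product per (output, input) pair.
int64_t fc_macs(int out_c, int in_c) {
    return static_cast<int64_t>(out_c) * in_c;
}
```

As a worked example with made-up shapes, a 3x3 convolution from 1 input channel to a 13x13x4 output costs 13 * 13 * 4 * 9 = 6084 MACs.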

tm_layers.cpp

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tinymaix.h"
#include <float.h>
#include <math.h>

/*************************** TML_CONV2D **********************************/
static uint32_t k_oft[TM_MAX_KSIZE]; 
//for valid or kernel in valid part, use fast method
tm_err_t __attribute__((weak)) tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
    int kw, int kh, int sx, int sy, int dx, int dy, int act, \
    int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp) //kernel: (cho, chi, h, w)
{   TM_DBGT_INIT();
    int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
    if(dx!=1 || dy!= 1) return TM_ERR_TODO;   
    if(act >= TM_ACT_MAXCNT) return TM_ERR_UNSUPPORT;   
    int maxk = kw*kh;
    if(maxk>TM_MAX_KSIZE) return TM_ERR_KSIZE;
    int chi  = in->c; 
    int cho  = out->c;

    int oft = 0;
    int idx = 0;
    for(int y=0; y<kh; y++){
        for(int x=0; x<kw; x++){
            k_oft[idx] = oft;
            idx += 1;
            oft += chi;
        }
        oft += (in->w - kw)*chi; 
    }

    chi  = dmul ? 1 : in->c; // dmul>=1 indicate depthwise; dummy chi for dwconv compatible
    mtype_t* outp;
    int slow_flag = 0; //same pad part is slow
    sumtype_t sum_pad = 0;
    int32_t os = (1<<TM_FASTSCALE_SHIFT)/out_s;
    for (int c = 0; c < out->c; c++) {//TM_DBGL();
        int32_t scale=1.0/ws[c]/in_s;
        outp = out->data + c;
        sum_pad = 0;
        if(pad_flag && TM_MDL_TYPE != TM_MDL_FP32) { // fix pad sum fuse
            for (int k = c*chi*maxk; k < (c+1)*chi*maxk; k++) 
                sum_pad += in_zp*((wtype_t*)w)[k];
        }
        for (int y = 0; y < out->h; y++) {//TM_DBGL();
            int src_y0 = sy*y - pad_top;
            int src_y1 = src_y0+kh;
            for (int x = 0; x < out->w; x++) {
                int src_x0 = sx*x - pad_left;
                int src_x1 = src_x0+kw;
                sumtype_t sum = 0;
                slow_flag = ((src_y0<0)+(src_x0<0)+(src_y1>in->h)+(src_x1>in->w));
                if(!slow_flag) { //valid or same valid part
                    wtype_t* kptr = (wtype_t*)w + c*chi*maxk;
                    mtype_t* sptr = (mtype_t*)TM_MATP(in, src_y0, src_x0, dmul?c/dmul:0); 
                    if(maxk==1){ //speed up pointwise conv
                        for (int cc = 0; cc < chi; cc++) { 
                            sum +=  (sumtype_t)sptr[0]  * (sumtype_t)kptr[0];kptr += 1;sptr += 1;
                        }
                    }else {
                        for (int cc = 0; cc < chi; cc++) {
                            for (int k = 0; k < maxk; k++) {
                                sumtype_t val = (sumtype_t)sptr[k_oft[k]];
                                sumtype_t wt  = (sumtype_t)kptr[k];
                                sum += val * wt;
                            }
                            kptr += maxk;
                            sptr += 1;
                        }
                    }
                } else {    //same pad part  //slower
                    int _ky0 = src_y0<0 ? -src_y0 : 0;
                    int _kx0 = src_x0<0 ? -src_x0 : 0;
                    int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
                    int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
                    wtype_t* kptr = (wtype_t*)w + c*chi*maxk;
                    mtype_t* sptr = (mtype_t*)TM_MATP(in, src_y0, src_x0, dmul?c/dmul:0);  //virtual sptr position in pad
                    sum += sum_pad; //pad_all
                    for (int cc = 0; cc < chi; cc++) {
                        for(int _ky=_ky0; _ky<_ky1; _ky++){
                            for(int _kx=_kx0; _kx<_kx1; _kx++){
                                int k = _ky*kw + _kx;
                                sumtype_t val = ((sumtype_t)sptr[k_oft[k]]-in_zp);
                                sumtype_t wt  = (sumtype_t)kptr[k];
                                sum += val * wt;
                            }
                        }
                        kptr += maxk;
                        sptr += 1;
                    }
                }
                sum += b[c];    //bias, fuse with in_zp
                
            #if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16 )
                #if !TM_FASTSCALE
                    float sumf = sum*ws[c]*in_s;
                #else 
                    sumtype_t sumf = (sum<<TM_FASTSCALE_SHIFT)/scale;
                #endif
            #else
                float sumf = sum;
            #endif
                switch(act){    //activation func
                case TM_ACT_RELU:
                    sumf = sumf>0?sumf:0;
                    break;
                case TM_ACT_RELU6:
                    sumf = sumf>0?sumf:0;
                #if (TM_MDL_TYPE == TM_MDL_FP32) || (!TM_FASTSCALE)
                    sumf = sumf>6?6:sumf;
                #else
                    sumf = sumf>(6<<TM_FASTSCALE_SHIFT)?(6<<TM_FASTSCALE_SHIFT):sumf;
                #endif
                    break;
                default:
                    break;
                }

            #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16   //requant 
                #if !TM_FASTSCALE
                    *outp = (mtype_t)(sumf/out_s + out_zp);  //(mtype_t)((int)(sumf/out_s) + out_zp) //(mtype_t)((int)(sumf/out_s +0.5) + out_zp)
                #else 
                    *outp = (mtype_t)(((sumf*os)>>(TM_FASTSCALE_SHIFT+TM_FASTSCALE_SHIFT))+out_zp);
                #endif
            #else
                *outp = (mtype_t)sumf;
            #endif
                outp += out->c;
            }
        }
    }
    return TM_OK;
}

/*************************** TML_GAP **********************************/
tm_err_t __attribute__((weak)) tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT();
    mtype_t* data;
    for(int c=0; c <out->c; c++){
        sumtype_t sum = 0;
        data = in->data + c;
        for(int y=0; y <in->h; y++){
            for(int x=0; x <in->w; x++){
                sum  += ((sumtype_t)(*data));
                data += out->c;
            }
        }
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16 
        out->data[c] = (mtype_t)((sum/((in->h)*(in->w))-in_zp)*in_s/out_s + out_zp); //requant
    #else
        out->data[c] = (mtype_t)(sum/((in->h)*(in->w)));
    #endif
    }
    return TM_OK;
}

/*************************** TML_FC **********************************/
tm_err_t __attribute__((weak)) tml_fc(tm_mat_t* in, tm_mat_t* out,  wtype_t* w, btype_t* b, \
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT();
    mtype_t* data = in->data;
    for(int c=0; c <out->c; c++){
        sumtype_t sum = 0;
        for(int cc=0; cc <in->c; cc++){
            sum += ((sumtype_t)data[cc])*(w[c*in->c+cc]);
        }
        sum += b[c];    //fuse with zp
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16 
        out->data[c] = (mtype_t)(sum*in_s*ws[0]/out_s + out_zp); //requant
    #else
        out->data[c] = (mtype_t)(sum);
    #endif
    }
    return TM_OK;
}

/*************************** TML_SOFTMAX **********************************/
tm_err_t __attribute__((weak)) tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT(); //note we have float size output buf even in INT8/INT16 mode
    mtype_t* din = in->data;
    float*  dout = (float*)(out->data); 
    float   dmax =  -FLT_MAX;
    for(int c=0; c <in->c; c++){
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16 
        dout[c] = (float)((sumtype_t)din[c] - in_zp)*in_s;
    #else
        dout[c] = din[c];
    #endif
        if(dout[c] > dmax) dmax = dout[c];
    }
    float sum = 0;
    for(int c=0; c <in->c; c++){
        dout[c] -= dmax;
        dout[c] = (float)exp(dout[c]);
        sum     += dout[c];
        dout[c] -= 0.000001;  //prevent 1.0 value (cause 256 overflow)
    }
    for(int c=0; c <in->c; c++){  //int8/int16 <= fp32, so it is ok
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16 
        out->data[c] = (mtype_t)(dout[c]/sum/out_s + out_zp); //requant
    #else
        out->data[c] = (mtype_t)(dout[c]/sum);
    #endif
    }
    return TM_OK;
}

/*************************** TML_RESHAPE **********************************/
tm_err_t __attribute__((weak)) tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   
    //in fact do nothing... out shape
    return TM_OK;
}
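
Every layer above takes `in_s`, `in_zp`, `out_s` and `out_zp` arguments. These describe an affine quantization in which a real value is approximated as `(q - zero_point) * scale`. The sketch below shows that mapping for int8 data, assuming round-to-nearest and saturation. The listing itself truncates and does not saturate, so this illustrates the idea rather than copying the TinyMaix arithmetic, and the function names are hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// real ~= (q - zero_point) * scale
float dequantize(int8_t q, float scale, int zero_point) {
    return static_cast<float>(static_cast<int>(q) - zero_point) * scale;
}

// Inverse mapping, rounded to nearest and saturated to the int8 range.
int8_t quantize(float real, float scale, int zero_point) {
    long q = std::lround(real / scale) + zero_point;
    q = std::min(127L, std::max(-128L, q));
    return static_cast<int8_t>(q);
}

// Re-expresses a value quantized with (in_s, in_zp) in the (out_s, out_zp) domain.
int8_t requantize(int8_t q_in, float in_s, int in_zp, float out_s, int out_zp) {
    return quantize(dequantize(q_in, in_s, in_zp), out_s, out_zp);
}
```

With a scale of 1/256 and a zero point of -128, for instance, the real value 0.5 maps to the quantized value 0 and back to 0.5.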

tm_model.cpp

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tinymaix.h"
#include "string.h"

//dirty port for avr chip
#if 1
#define LBUF_LEN (360)
#define LAYERBUF_SIZE (LBUF_LEN)
static uint8_t l_buf[LAYERBUF_SIZE];
static const uint8_t* l_bin;
#define LAYER_BODY (l_buf+sizeof(tm_mdlbin_t))
#define TM_READ_LAYER(dst,src,num)   memcpy((void*)(dst),(void*)(src),(num))
//load model
//mdl: model handle; bin: model bin buf; buf: main buf for middle output; cb: layer callback; 
//in: return input mat, include buf addr; //you can ignore it if use static buf
tm_err_t tm_load  (tm_mdl_t* mdl, const uint8_t* bin, uint8_t*buf, tm_cb_t cb, tm_mat_t* in)
{   l_bin = bin;
    TM_READ_LAYER(l_buf, bin, sizeof(tm_mdlbin_t));
    tm_mdlbin_t* mdl_bin = (tm_mdlbin_t*)l_buf;
    char* tmp = (char*)mdl_bin;
    TM_PRINTF("%c%c%c%c\n",tmp[0],tmp[1],tmp[2],tmp[3]);
    if(tmp[0]!='M'||tmp[1]!='A'||tmp[2]!='I'||tmp[3]!='X')return TM_ERR_MAGIC;
    //if(mdl_bin->magic != TM_MDL_MAGIC)   return TM_ERR_MAGIC; //avr is big endian...
    if(mdl_bin->mdl_type != TM_MDL_TYPE) return TM_ERR_MDLTYPE;
    mdl->b          = mdl_bin;
    mdl->cb         = (void*)cb; 
    if(buf == NULL) {
        mdl->buf        = (uint8_t*)tm_malloc(mdl->b->buf_size);
        if(mdl->buf == NULL) return TM_ERR_OOM;
        mdl->main_alloc = 1;
    } else {
        mdl->buf = buf;
        mdl->main_alloc = 0;
    }
    if(mdl->b->sub_size > 0) {
        mdl->subbuf = (uint8_t*)tm_malloc(mdl->b->sub_size);
        if(mdl->subbuf == NULL) {
            if(mdl->main_alloc) tm_free(mdl->buf);  //do not leak the main buffer
            return TM_ERR_OOM;
        }
    } else mdl->subbuf = NULL;
    mdl->layer_i    = 0;
    mdl->layer_body = mdl->b->layers_body;
    memcpy((void*)in, (void*)mdl->b->in_dims, sizeof(uint16_t)*4);  //copy only the 4 dims; data is set below
    in->data = (mtype_t*)mdl->buf; //input at 0 oft
    return TM_OK;
}

void tm_unload(tm_mdl_t* mdl)               
{
    if(mdl->main_alloc) tm_free(mdl->buf);
    if(mdl->subbuf) tm_free(mdl->subbuf);
    return;
}


//preprocess data input
tm_err_t tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out)
{   tm_mdlbin_t* b = (tm_mdlbin_t*)l_buf;
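    TM_READ_LAYER(LAYER_BODY,l_bin+sizeof(tm_mdlbin_t),sizeof(tml_head_t));
    tml_head_t* l0h = (tml_head_t*)(l_buf+sizeof(tm_mdlbin_t));
    if(sizeof(tm_mdlbin_t) + l0h->size > LAYERBUF_SIZE) return TM_ERR_OOM; //layer 0 does not fit in l_buf
    TM_READ_LAYER(LAYER_BODY,l_bin+sizeof(tm_mdlbin_t),l0h->size);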

    sctype_t in_s = l0h->in_s;
    zptype_t in_zp= l0h->in_zp;
    int in_size = in->h*in->w*in->c;
    switch(pp_type){
#if TM_MDL_TYPE != TM_MDL_FP32
    case TMPP_FP2INT:
        for(int i=0; i<in_size; i++)
            out->data[i] = (mtype_t)(in->dataf[i]/in_s + in_zp);
        break;
    case TMPP_UINT2INT:
        for(int i=0; i<in_size; i++)
            out->data[i] = ((mtype_t)(((uint8_t*)(in->data))[i]-128))<<UINT2INT_SHIFT;
            //out->data[i] = ((mtype_t)(((uint8_t)pgm_read_byte(in->data+i)-128)))<<UINT2INT_SHIFT;
        break;
#else
    case TMPP_UINT2FP01:
        for(int i=0; i<in_size; i++)
            out->dataf[i] = (((uint8_t*)(in->data))[i])/255.0; 
        break;
    case TMPP_UINT2FPN11:
        for(int i=0; i<in_size; i++)
            out->dataf[i] = ((((uint8_t*)(in->data))[i])-128)/128.0;
        break;
#endif
    default:    //don't do anything
        out->data = in->data;
        break;
    }
    
    return TM_OK;
}


//run model
//mdl: model handle; in: input mat; out: output mat
tm_err_t tm_run(tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out)
{
    tm_mat_t _in, _out;
    tm_err_t res = TM_OK;
    int out_idx = 0;
    tml_head_t* h;
    memcpy((void*)&_in, (void*)in, sizeof(tm_mat_t));     
    mdl->layer_body = (uint8_t*)(l_bin+sizeof(tm_mdlbin_t));
    for(mdl->layer_i = 0; mdl->layer_i < mdl->b->layer_cnt; mdl->layer_i++){
        TM_READ_LAYER(LAYER_BODY,mdl->layer_body,sizeof(tml_head_t));
        h = (tml_head_t*)(LAYER_BODY);
        if(sizeof(tm_mdlbin_t) + h->size > LAYERBUF_SIZE) return TM_ERR_OOM; //layer does not fit in l_buf
        TM_READ_LAYER(LAYER_BODY,mdl->layer_body,h->size);
        
        if(mdl->layer_i>0) {
            _in.data  = (mtype_t*)(mdl->buf + h->in_oft);
            memcpy((void*)&_in, (void*)(h->in_dims), sizeof(uint16_t)*4);
        }
        _out.data = (mtype_t*)(mdl->buf + h->out_oft);
        memcpy((void*)&_out, (void*)(h->out_dims), sizeof(uint16_t)*4);
        switch(h->type){
        case TML_CONV2D: 
        case TML_DWCONV2D:{
            tml_conv2d_dw_t* l = (tml_conv2d_dw_t*)(LAYER_BODY);
            res = tml_conv2d_dwconv2d(&_in, &_out, (wtype_t*)(LAYER_BODY + l->w_oft), (btype_t*)(LAYER_BODY + l->b_oft), \
                l->kernel_w, l->kernel_h, l->stride_w, l->stride_h, l->dilation_w, l->dilation_h, \
                l->act, l->pad[0], l->pad[1], l->pad[2], l->pad[3], l->depth_mul, \
                (sctype_t*)(LAYER_BODY + l->ws_oft), h->in_s, h->in_zp, h->out_s, h->out_zp);
            break;}
        case TML_GAP: {
            tml_gap_t* l = (tml_gap_t*)(LAYER_BODY);
            res = tml_gap(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
            break;}
        case TML_FC: {
            tml_fc_t* l = (tml_fc_t*)(LAYER_BODY);
            res = tml_fc(&_in, &_out, (wtype_t*)(LAYER_BODY + l->w_oft), (btype_t*)(LAYER_BODY + l->b_oft), \
                (sctype_t*)(LAYER_BODY + l->ws_oft), h->in_s, h->in_zp, h->out_s, h->out_zp);
            break;}
        case TML_SOFTMAX: {
            tml_softmax_t* l = (tml_softmax_t*)(LAYER_BODY);
            res = tml_softmax(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
            break; }
        case TML_RESHAPE: {
            tml_reshape_t* l = (tml_reshape_t*)(LAYER_BODY);
            res = tml_reshape(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
            break; }
        default:
            res = TM_ERR_LAYERTYPE;
            break;
        }
        if(res != TM_OK) return res;
        if(mdl->cb) ((tm_cb_t)mdl->cb)(mdl, h);    //layer callback
        if(h->is_out) {
            memcpy((void*)(&out[out_idx]), (void*)(&(h->out_dims)), sizeof(uint16_t)*4);
            if(mdl->b->out_deq == 0 || TM_MDL_TYPE == TM_MDL_FP32) //fp32 do not need deq
                out[out_idx].data = (mtype_t*)(TML_GET_OUTPUT(mdl, h));
            else {
                int out_size = h->out_dims[1]*h->out_dims[2]*h->out_dims[3];
                float* outf = (float*)(TML_GET_OUTPUT(mdl, h) + (out_size+7)/8*8);
                for(int i=0; i<out_size; i++) //do dequant
                    outf[i] = TML_DEQUANT(h, (TML_GET_OUTPUT(mdl, h))[i]);
                out[out_idx].dataf = outf;
            }
            out_idx += 1;
        }
        mdl->layer_body += (h->size);
    }
    return TM_OK;
}

#endif
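
`tm_run()` walks the model binary as a sequence of variable-size records: it reads a header, uses its `size` field to copy the whole layer, then advances by `size` bytes. The sketch below isolates that traversal pattern with bounds checks added. `RecordHead` is a simplified, hypothetical stand-in for `tml_head_t` (the real header has more fields), and `walk_records` is not a TinyMaix function.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical, simplified record header; not the real tml_head_t layout.
struct RecordHead {
    uint16_t type;   // layer kind
    uint16_t pad;    // keeps the header at 8 bytes
    uint32_t size;   // total bytes of this record, header included
};

// Returns the type of each of the `count` records, or an empty vector
// if a record is truncated or reports an impossible size.
std::vector<uint16_t> walk_records(const uint8_t* blob, size_t blob_len, size_t count) {
    std::vector<uint16_t> types;
    size_t offset = 0;                            // invariant: offset <= blob_len
    for (size_t i = 0; i < count; i++) {
        if (blob_len - offset < sizeof(RecordHead)) return {};
        RecordHead head;
        std::memcpy(&head, blob + offset, sizeof head);   // copy: blob may be unaligned
        if (head.size < sizeof(RecordHead) || head.size > blob_len - offset) return {};
        types.push_back(head.type);
        offset += head.size;                      // jump to the next record
    }
    return types;
}
```

Copying the header out with `memcpy` before reading its fields mirrors what `TM_READ_LAYER` does in the listing, and it avoids unaligned access when the blob starts at an arbitrary address.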

Results

The demo has to be run several times before a recognition result appears in the log.

E/CPP: NULL
E/CPP: 
E/CPP: 9
E/CPP: : 
E/CPP: NULL
E/CPP: 
E/CPP: ### Predict output is: Number 
E/CPP: 2
E/CPP: , prob=
E/CPP: 89
E/CPP: 
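
In the log above each fragment lands on its own logcat line, which makes the prediction hard to read. One possible remedy, assuming every `Serial::print` call currently maps to one log call, is to buffer fragments until a newline arrives and only then emit the line. `LineLogger` below is a hypothetical sketch of that idea. It takes the output function as a parameter so it can be tried off-device.

```cpp
#include <functional>
#include <string>
#include <utility>

// Hypothetical helper: collects printed fragments and emits one message per line.
class LineLogger {
public:
    // `sink` receives each completed line without its trailing '\n'.
    explicit LineLogger(std::function<void(const std::string&)> sink)
        : sink_(std::move(sink)) {}

    // Appends text; every '\n' flushes the pending line to the sink.
    void print(const std::string& text) {
        for (char ch : text) {
            if (ch == '\n') flush();
            else line_ += ch;
        }
    }

    // Appends the decimal form of an integer.
    void print(int value) { line_ += std::to_string(value); }

    // Emits pending text, if any; call once more at the end of a run.
    void flush() {
        if (!line_.empty()) sink_(line_);
        line_.clear();
    }

private:
    std::function<void(const std::string&)> sink_;
    std::string line_;
};
```

On Android the sink could forward to `__android_log_print(ANDROID_LOG_ERROR, "CPP", "%s", line.c_str())` from `<android/log.h>`, so that the prediction, its digit and its probability arrive as a single log entry.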
[Image: handwritten digit] [Image: recognition result]
posted @ 2023-08-10 05:04  qsBye