tflite移植

移植AI到Android平台

提取模型

import tensorflow as tf
    
# Build a trivial graph: output = matrix + matrix, for a 1x2 float32 input.
# Shape dimensions must be integers; float dims such as [1.0, 2.0] are rejected
# by newer TensorFlow releases ("Dimension value must be integer or None").
matrix = tf.placeholder(tf.float32, [1, 2], name='matrix_')
out = tf.add(matrix, matrix, name='output')

with tf.Session() as sess:
    # The graph has no variables, so this is a no-op kept for graphs that do.
    sess.run(tf.global_variables_initializer())
    # Sanity check only; this line can be removed.
    print(sess.run(out, feed_dict={matrix: [[2.0, 3.0]]}))
    # Convert the live session graph, using `matrix` as input and `out` as output.
    converter = tf.lite.TFLiteConverter.from_session(sess, [matrix], [out])
    tflite_model = converter.convert()
    # Context manager guarantees the model bytes are flushed and the file closed.
    with open("D:/s/converted_model.tflite", "wb") as model_file:
        model_file.write(tflite_model)

运行后生成模型文件 converted_model.tflite

Python读取模型并测试

import numpy as np
import tensorflow as tf

# Open the converted model and reserve memory for its tensors.
interp = tf.lite.Interpreter(model_path="D:/s/converted_model.tflite")
interp.allocate_tensors()

# The model has a single input and a single output; keep their descriptors.
in_info = interp.get_input_details()[0]
out_info = interp.get_output_details()[0]

# Expected to print [1 2] for the model exported above.
print(in_info['shape'])

sample = np.array([[2, 4]], dtype=np.float32)
# Alternative: random input of the declared shape, e.g.
#   sample = np.array(np.random.random_sample(in_info['shape']), dtype=np.float32)

# Feed the input, run the graph, and print the result tensor.
interp.set_tensor(in_info['index'], sample)
interp.invoke()
print(interp.get_tensor(out_info['index']))

Android读取模型并测试

Android源码:

//MainActivity.java
package com.learn.testtf;

import android.Manifest;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.support.annotation.NonNull;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Toast;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;

import org.tensorflow.lite.Interpreter;

public class MainActivity extends AppCompatActivity {

    private static final String TAG = "Test";

    /** Base name (without the ".tflite" extension) of the model bundled in assets. */
    private static final String MODEL_NAME = "converted_model";

    /** Number of inference runs performed by {@link #test(View)}. */
    private static final int TEST_ITERATIONS = 100000;

    /** TFLite interpreter; null until loadModule() succeeds and after onDestroy(). */
    private Interpreter tflite;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        loadModule();
    }

    @Override
    protected void onDestroy() {
        // The interpreter holds native resources; release them with the activity.
        if (tflite != null) {
            tflite.close();
            tflite = null;
        }
        super.onDestroy();
    }

    /**
     * Creates the interpreter for the bundled model and shows a toast reporting
     * whether loading succeeded.
     */
    private void loadModule() {
        String model = MODEL_NAME;
        try {
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(4);
            options.setUseNNAPI(true);
            options.setAllowFp16PrecisionForFp32(true);
            tflite = new Interpreter(loadModelFile(model), options);
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
        } catch (IOException | IllegalArgumentException e) {
            // IOException: asset missing/compressed; IllegalArgumentException: the
            // Interpreter constructor rejects an invalid model.
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            Log.e(TAG, "Failed to load model " + model, e);
        }
    }

    /**
     * Memory-maps {@code <model>.tflite} from the APK assets.
     *
     * <p>The descriptor and stream are closed before returning; a mapping created by
     * {@link FileChannel#map} remains valid after its channel is closed.
     *
     * @param model asset base name without the ".tflite" extension
     * @return a read-only mapping of the model bytes
     * @throws IOException if the asset cannot be opened or mapped
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        try (AssetFileDescriptor fileDescriptor =
                     getApplicationContext().getAssets().openFd(model + ".tflite");
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    /**
     * View callback: runs the model {@link #TEST_ITERATIONS} times on the fixed input
     * [7.0, 8.0] and logs both output values after every run.
     *
     * <p>NOTE(review): presumably wired via android:onClick, i.e. it runs on the UI
     * thread; this many inferences plus per-iteration logging may block the UI long
     * enough to cause an ANR. Consider a background thread if this is kept.
     *
     * @param view the clicked view (unused)
     */
    public void test(View view) {
        Log.d(TAG, "-----test()----");
        if (tflite == null) {
            // Model loading failed (or activity destroyed); nothing to run.
            Log.w(TAG, "test() called before the interpreter was loaded");
            return;
        }
        try {
            float[][] labelProbArray = new float[1][2];
            // Two float32 values, 4 bytes each, in native byte order.
            ByteBuffer inputData = ByteBuffer.allocateDirect(2 * 4);
            inputData.order(ByteOrder.nativeOrder());
            inputData.putFloat(7.0f);
            inputData.putFloat(8.0f);
            for (int i = 0; i < TEST_ITERATIONS; i++) {
                // Reset the position so the whole buffer is visible as input on every run.
                inputData.rewind();
                tflite.run(inputData, labelProbArray);
                Log.d(TAG, "labelProbArray[0]" + labelProbArray[0][0]);
                Log.d(TAG, "labelProbArray[1]" + labelProbArray[0][1]);
            }
        } catch (Exception e) {
            Log.e(TAG, "Inference failed", e);
        }
    }
}

 

// App-module build script for the TFLite demo (module containing MainActivity).
apply plugin: 'com.android.application'

android {
    compileSdkVersion 28
    defaultConfig {
        applicationId "com.learn.testtf"
        minSdkVersion 26
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
    }
    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }

    // Store .tflite assets uncompressed in the APK. MainActivity opens the model with
    // AssetManager.openFd() and memory-maps it, which requires an uncompressed asset.
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
    // TFLite Java runtime (provides org.tensorflow.lite.Interpreter).
    // NOTE(review): "0.0.0-nightly" resolves to whichever nightly is current, so builds
    // are not reproducible; consider pinning a released version that still provides the
    // Interpreter.Options methods MainActivity calls.
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
}

另外将converted_model.tflite放入assets目录（app/src/main/assets）中

 

posted @ 2019-04-13 16:13  牧 天  阅读(640)  评论(0)    收藏  举报