tflite移植
移植AI到Android平台
提取模型
"""Build a trivial TF1 graph (output = matrix + matrix) and export it as a TFLite model."""
import tensorflow as tf

# Where the converted model is written; adjust for your machine.
MODEL_PATH = "D:/s/converted_model.tflite"

# Input placeholder with a fixed [1, 2] shape. Tensor dimensions are integers,
# so use [1, 2] rather than [1.0, 2.0].
matrix = tf.placeholder(tf.float32, [1, 2], name='matrix_')
out = tf.add(matrix, matrix, name='output')

with tf.Session() as sess:
    # The graph has no variables; the initializer is harmless and keeps the
    # script valid if variables are added later.
    sess.run(tf.global_variables_initializer())

    # Optional sanity check, for testing only: prints [[4. 6.]].
    print(sess.run(out, feed_dict={matrix: [[2.0, 3.0]]}))

    converter = tf.lite.TFLiteConverter.from_session(sess, [matrix], [out])
    tflite_model = converter.convert()

# Use a context manager so the output file is flushed and closed deterministically.
with open(MODEL_PATH, "wb") as model_file:
    model_file.write(tflite_model)
运行上述脚本后,生成模型文件 converted_model.tflite
用 Python 读取模型并测试
"""Load the converted TFLite model, feed it a single [1, 2] float32 input and print the result."""
import numpy as np
import tensorflow as tf

# Create the interpreter from the exported file. Tensors must be allocated
# before any input can be written.
interpreter = tf.lite.Interpreter(model_path="D:/s/converted_model.tflite")
interpreter.allocate_tensors()

# The model has exactly one input and one output tensor.
in_info = interpreter.get_input_details()[0]
out_info = interpreter.get_output_details()[0]

# For the model exported above this prints [1 2].
print(in_info['shape'])

# A random input of the right shape would also work, e.g.
# np.array(np.random.random_sample(in_info['shape']), dtype=np.float32).
sample = np.array([[2, 4]], dtype=np.float32)
interpreter.set_tensor(in_info['index'], sample)
interpreter.invoke()

result = interpreter.get_tensor(out_info['index'])
print(result)
Android读取模型并测试
Android源码:
//MainActivity.java package com.learn.testtf; import android.Manifest; import android.content.pm.PackageManager; import android.content.res.AssetFileDescriptor; import android.graphics.Bitmap; import android.support.annotation.NonNull; import android.support.v4.app.ActivityCompat; import android.support.v4.content.ContextCompat; import android.support.v7.app.AppCompatActivity; import android.os.Bundle; import android.util.Log; import android.view.View; import android.widget.Toast; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.List; import org.tensorflow.lite.Interpreter; public class MainActivity extends AppCompatActivity { private static final String TAG = "Test"; private Interpreter tflite; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); loadModule(); } private void loadModule() { String model = "converted_model"; try { Interpreter.Options options = new Interpreter.Options(); options.setNumThreads(4); options.setUseNNAPI(true);
options.setAllowFp16PrecisionForFp32(true);
tflite = new Interpreter(loadModelFile(model), options);
Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show(); } catch (IOException e) { Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show(); e.printStackTrace(); } } private MappedByteBuffer loadModelFile(String model) throws IOException { AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite"); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } public void test(View view) { Log.d(TAG, "-----test()----"); try { float[][] labelProbArray = new float[1][2]; ByteBuffer inputData = ByteBuffer.allocateDirect(2 * 4); inputData.order(ByteOrder.nativeOrder()); inputData.putFloat(7.0f); inputData.putFloat(8.0f); for (int i = 0;i < 100000;i++) { tflite.run(inputData, labelProbArray); Log.d(TAG, "labelProbArray[0]" + labelProbArray[0][0]); Log.d(TAG, "labelProbArray[1]" + labelProbArray[0][1]); } } catch (Exception e) { e.printStackTrace(); } } }
// Module-level build script for the TFLite demo app.
apply plugin: 'com.android.application'

android {
    compileSdkVersion 28

    defaultConfig {
        applicationId "com.learn.testtf"
        minSdkVersion 26
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"
        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }

    // Store .tflite assets uncompressed in the APK; MainActivity opens the model with
    // AssetManager.openFd() and memory-maps it.
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    // App runtime
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:1.1.3'
    // NOTE(review): a nightly snapshot makes builds non-reproducible; consider pinning a released version.
    implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'

    // Tests
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
}
另外,需将 converted_model.tflite 放入 Android 工程的 assets 目录(通常为 app/src/main/assets)中。
                    
                
                
            
        
浙公网安备 33010602011771号