tflite识别图片
package com.learn.testtf;

import android.Manifest;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AppCompatActivity;
import android.util.Log;
import android.view.View;
import android.widget.Toast;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.tensorflow.lite.Interpreter;

/**
 * Demo activity that loads a TensorFlow Lite model from the app assets and, on click,
 * repeatedly classifies a fixed image from external storage, logging the best label,
 * its score and the inference time.
 *
 * <p>All interpreter calls are funnelled through a single background thread so that the
 * UI thread is never blocked and the interpreter is never used by two threads at once.
 */
public class MainActivity extends AppCompatActivity {
    private static final String TAG = "Test";

    /** Request code passed to (and matched in the callback of) the permission request. */
    private static final int PERMISSION_REQUEST_CODE = 1;

    /** Number of consecutive predictions started by one click. */
    private static final int BENCHMARK_RUNS = 1000;

    /** Image that is classified on every prediction. */
    private static final String IMAGE_PATH = "/storage/emulated/0/Pictures/pic/pic/Koala.jpg";

    /** Selects which bundled model is loaded: mobilenet_v1 when true, converted_model otherwise. */
    private static final boolean USE_MOBILENET = true;

    /**
     * Input dimensions handed to {@link PhotoUtil#getScaledMatrix}. Indices 2 and 3 are used as
     * the scaled image size; the product of all four sizes the input buffer (in floats).
     */
    private final int[] ddims = {1, 3, 224, 224};

    /** Labels read from the cacheLabel.txt asset; index = class id. */
    private final List<String> resultLabel = new ArrayList<>();

    /** Serializes every use of {@link #tflite} off the UI thread. */
    private final ExecutorService inferenceExecutor = Executors.newSingleThreadExecutor();

    /** Set in onDestroy so that queued benchmark loops stop early. */
    private volatile boolean destroyed;

    /** Loaded model, or null when loading failed or the activity has been destroyed. */
    private Interpreter tflite;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        readCacheLabelFromLocalFile();
        request_permissions();
        loadModule();
    }

    /** Stops pending work and releases the interpreter on the thread that uses it. */
    @Override
    protected void onDestroy() {
        destroyed = true;
        // Close on the inference thread so the interpreter is never closed mid-run.
        inferenceExecutor.execute(new Runnable() {
            @Override
            public void run() {
                if (tflite != null) {
                    tflite.close();
                    tflite = null;
                }
            }
        });
        inferenceExecutor.shutdown();
        super.onDestroy();
    }

    /** Creates the interpreter for the selected model and reports the outcome in a toast. */
    private void loadModule() {
        String model = USE_MOBILENET ? "mobilenet_v1" : "converted_model";
        try {
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(10);
            options.setUseNNAPI(true);
            options.setAllowFp16PrecisionForFp32(true);
            tflite = new Interpreter(loadModelFile(model), options);
            Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
        } catch (IOException | IllegalArgumentException e) {
            // IllegalArgumentException: presumably raised by the interpreter for an unusable
            // model buffer; treated the same as a missing asset instead of crashing onCreate.
            Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
            Log.e(TAG, model + " model load fail", e);
        }
    }

    /**
     * Memory-maps the asset {@code <model>.tflite}.
     *
     * @param model asset name without the ".tflite" extension
     * @return read-only mapping of the model bytes
     * @throws IOException if the asset cannot be opened or mapped
     */
    private MappedByteBuffer loadModelFile(String model) throws IOException {
        AssetFileDescriptor fileDescriptor =
                getApplicationContext().getAssets().openFd(model + ".tflite");
        try {
            FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
            try {
                FileChannel fileChannel = inputStream.getChannel();
                long startOffset = fileDescriptor.getStartOffset();
                long declaredLength = fileDescriptor.getDeclaredLength();
                // The mapping stays valid after the channel and descriptor are closed.
                return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
            } finally {
                inputStream.close();
            }
        } finally {
            fileDescriptor.close();
        }
    }

    /**
     * Click handler: queues {@link #BENCHMARK_RUNS} predictions on the inference thread.
     * The run stops early on the first failed prediction or when the activity is destroyed.
     */
    public void onClick(View view) {
        Log.d(TAG, "-----onClick()----");
        if (destroyed) {
            return;
        }
        inferenceExecutor.execute(new Runnable() {
            @Override
            public void run() {
                for (int i = 0; i < BENCHMARK_RUNS && !destroyed; i++) {
                    if (!predict_image()) {
                        // Same input every iteration: a failure would simply repeat.
                        break;
                    }
                }
            }
        });
    }

    /**
     * Smoke test for a model taking two floats and producing two floats; runs on the
     * inference thread and logs both outputs for each of 100000 iterations.
     */
    private void test() {
        if (destroyed) {
            return;
        }
        inferenceExecutor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    for (int i = 0; i < 100000 && !destroyed; i++) {
                        Interpreter interpreter = tflite;
                        if (interpreter == null) {
                            Log.e(TAG, "model is not loaded");
                            return;
                        }
                        float[][] labelProbArray = new float[1][2];
                        ByteBuffer inputData = ByteBuffer.allocateDirect(2 * 4);
                        inputData.order(ByteOrder.nativeOrder());
                        inputData.putFloat(7.0f);
                        inputData.putFloat(8.0f);
                        interpreter.run(inputData, labelProbArray);
                        Log.d(TAG, "labelProbArray[0]" + labelProbArray[0][0]);
                        Log.d(TAG, "labelProbArray[1]" + labelProbArray[0][1]);
                    }
                } catch (Exception e) {
                    Log.e(TAG, "test run failed", e);
                }
            }
        });
    }

    /**
     * Classifies {@link #IMAGE_PATH} once and logs the winning class, label, score and the
     * time spent inside the interpreter.
     *
     * @return true when a prediction was produced, false when it could not be made
     */
    private boolean predict_image() {
        Interpreter interpreter = tflite;
        if (interpreter == null) {
            Log.e(TAG, "model is not loaded; skipping prediction");
            return false;
        }
        try {
            // picture to float buffer
            Bitmap bmp = PhotoUtil.getScaleBitmap(IMAGE_PATH);
            if (bmp == null) {
                // Missing file, unreadable image, or storage permission not granted.
                Log.e(TAG, "cannot decode image: " + IMAGE_PATH);
                return false;
            }
            ByteBuffer inputData = PhotoUtil.getScaledMatrix(bmp, ddims);
            if (!bmp.isRecycled()) {
                bmp.recycle();
            }

            float[][] labelProbArray = new float[1][1001];
            long start = System.currentTimeMillis();
            // get predict result
            interpreter.run(inputData, labelProbArray);
            long time = System.currentTimeMillis() - start;

            // show predict result and time
            float[] results = labelProbArray[0];
            int r = get_max_result(results);
            // The label list may be shorter than the output if the label asset failed to load.
            String name = r < resultLabel.size() ? resultLabel.get(r) : "unknown";
            String show_text = "result:" + r + "\nname:" + name
                    + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
            Log.d(TAG, "show_text:" + show_text);
            return true;
        } catch (Exception e) {
            Log.e(TAG, "prediction failed", e);
            return false;
        }
    }

    /**
     * Returns the index of the largest value; the first one wins on ties.
     *
     * @param result non-empty score array
     */
    private int get_max_result(float[] result) {
        float probability = result[0];
        int r = 0;
        for (int i = 1; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    /** Requests the external-storage read/write permissions that are not yet granted. */
    private void request_permissions() {
        List<String> permissionList = new ArrayList<>();
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
                != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.WRITE_EXTERNAL_STORAGE);
        }
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE)
                != PackageManager.PERMISSION_GRANTED) {
            permissionList.add(Manifest.permission.READ_EXTERNAL_STORAGE);
        }
        // if list is not empty will request permissions
        if (!permissionList.isEmpty()) {
            ActivityCompat.requestPermissions(
                    this,
                    permissionList.toArray(new String[permissionList.size()]),
                    PERMISSION_REQUEST_CODE);
        }
    }

    /** Shows a toast for every permission of our own request that the user denied. */
    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
                                           @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode != PERMISSION_REQUEST_CODE) {
            return;
        }
        for (int i = 0; i < grantResults.length; i++) {
            if (grantResults[i] == PackageManager.PERMISSION_DENIED) {
                Toast.makeText(this, permissions[i] + " permission was denied",
                        Toast.LENGTH_SHORT).show();
            }
        }
    }

    /** Loads one label per line from the cacheLabel.txt asset; failures are logged only. */
    private void readCacheLabelFromLocalFile() {
        BufferedReader reader = null;
        try {
            AssetManager assetManager = getApplicationContext().getAssets();
            reader = new BufferedReader(new InputStreamReader(assetManager.open("cacheLabel.txt")));
            String readLine;
            while ((readLine = reader.readLine()) != null) {
                resultLabel.add(readLine);
            }
        } catch (IOException e) {
            Log.e(TAG, "error " + e, e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing a read-only asset stream fails.
                }
            }
        }
    }
}
package com.learn.testtf;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Static helpers that decode an image file and turn it into a float input buffer. */
public class PhotoUtil {

    /**
     * Scales {@code bitmap} to {@code ddims[2]} x {@code ddims[3]} and writes its pixels as
     * floats into a direct, native-order buffer. For each pixel, in the order returned by
     * {@code Bitmap.getPixels}, the red, green and blue channels are written one after another,
     * each mapped from 0..255 to {@code (c - 128) / 128}.
     *
     * <p>The source bitmap is left untouched; only a scaled copy created here is recycled.
     *
     * @param bitmap source image, must not be null
     * @param ddims  four sizes; their product times 4 is the buffer capacity in bytes, and
     *               entries 2 and 3 are the scaled width and height
     * @return the filled buffer (its position is left at the end of the written data)
     */
    public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
        final int width = ddims[2];
        final int height = ddims[3];

        ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * width * height * 4);
        imgData.order(ByteOrder.nativeOrder());

        // get image pixel
        int[] pixels = new int[width * height];
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, width, height, false);
        bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, width, height);

        for (int val : pixels) {
            imgData.putFloat((((val >> 16) & 0xFF) - 128f) / 128f);
            imgData.putFloat((((val >> 8) & 0xFF) - 128f) / 128f);
            imgData.putFloat(((val & 0xFF) - 128f) / 128f);
        }

        // createScaledBitmap may hand back the source itself when no scaling is needed;
        // in that case the bitmap still belongs to the caller and must not be recycled.
        if (bm != bitmap && !bm.isRecycled()) {
            bm.recycle();
        }
        return imgData;
    }

    /**
     * Decodes the image at {@code filePath}, down-sampling by powers of two until at least
     * one side is shorter than 500 pixels.
     *
     * @param filePath path of the image file
     * @return the decoded bitmap, or null when the file cannot be decoded
     */
    public static Bitmap getScaleBitmap(String filePath) {
        final int maxSize = 500;

        // First pass: read the dimensions only.
        BitmapFactory.Options opt = new BitmapFactory.Options();
        opt.inJustDecodeBounds = true;
        BitmapFactory.decodeFile(filePath, opt);
        int bmpWidth = opt.outWidth;
        int bmpHeight = opt.outHeight;

        // compress picture with inSampleSize
        opt.inSampleSize = 1;
        while (bmpWidth / opt.inSampleSize >= maxSize && bmpHeight / opt.inSampleSize >= maxSize) {
            opt.inSampleSize *= 2;
        }

        // Second pass: decode the pixels at the chosen sample size.
        opt.inJustDecodeBounds = false;
        return BitmapFactory.decodeFile(filePath, opt);
    }
}
mobilenet_v1.tflite
cacheLabel.txt

浙公网安备 33010602011771号