Cloned Code for TensorFlow Android
See https://www.tensorflow.org/mobile/ for the latest details, and more specifically https://www.tensorflow.org/mobile/android_build for the Android build instructions.
How to Clone a current example
git clone https://github.com/tensorflow/tensorflow
You will now have the Android example inside the cloned repository.
In Android Studio: File -> Open Project
Select the Android example directory from the clone.
NOW you may need to resolve any issues with required Gradle file updates.
Open the build.gradle file (you can go to 1:Project in the side panel and find it under the Gradle Scripts zippy under Android). Look for the nativeBuildSystem variable and set it to none if it isn't already:
// set to 'bazel', 'cmake', 'makefile', 'none'
def nativeBuildSystem = 'none'
Click Run button (the green arrow) or use Run -> Run 'android' from the top menu.
If it asks you to use Instant Run, click Proceed Without Instant Run.
Also, you need to have an Android device plugged in with developer options enabled at this point. See here for more details on setting up developer devices.
This cloned code contains multiple launchable activities:
Here is the result of running TF Detect (recognition AND localization).
--> The bundled detection model is an SSD MobileNet trained on the COCO dataset (see coco_labels_list.txt in the assets).
NOTE: Object tracking is currently NOT supported in this Java/Android example (the TF Detect activity-based application) -- so it performs still object identification and localization, but NOT tracking. See the README file in the cloned code, which says "object tracking is not available in the 'TF Detect' activity."
Explaining the Code and Project Setup
- TensorFlow is written in C++.
- In order to build for Android, we have to use JNI (Java Native Interface) to call the C++ functions such as loadModel and getPredictions (see the sketch below).
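For illustration only, here is roughly what that JNI pattern looks like. The class name, library name, and native method signatures below are hypothetical placeholders, not the real TensorFlow bindings (those are wrapped for you by the inference library described next).

// Hypothetical sketch of the JNI pattern -- not the actual TensorFlow bindings.
public class NativeClassifier {
    static {
        // Load the compiled C++ library (e.g. libnative_classifier.so) packaged in the APK.
        System.loadLibrary("native_classifier");
    }

    // Declarations only: the implementations live in C++ and are bound at runtime via JNI.
    public native long loadModel(String modelPath);
    public native float[] getPredictions(long modelHandle, float[] inputPixels);
}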
Android Inference Library
Because Android apps need to be written in Java, and core TensorFlow is in C++, TensorFlow has a JNI library to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. You can see the full documentation for the minimal set of methods in TensorFlowInferenceInterface.java
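Here is a minimal sketch of that load/feed/run/fetch flow, mirroring the calls used later in TensorFlowObjectDetectionAPIModel. The model path, tensor names ("input"/"output"), and shapes below are placeholder assumptions for illustration, not the demo's real values.

import android.content.res.AssetManager;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

// Minimal inference sketch -- model path, tensor names, and sizes are illustrative assumptions.
public class SimpleInference {
    private final TensorFlowInferenceInterface inferenceInterface;

    public SimpleInference(AssetManager assets) {
        // Loads the frozen GraphDef packaged in the APK's assets.
        inferenceInterface = new TensorFlowInferenceInterface(assets, "file:///android_asset/my_model.pb");
    }

    public float[] classify(float[] pixels) {
        final float[] outputs = new float[1000];
        // 1) Copy the input data into the graph's input tensor (batch of 1, 224x224 RGB).
        inferenceInterface.feed("input", pixels, 1, 224, 224, 3);
        // 2) Run the graph up to the requested output node(s).
        inferenceInterface.run(new String[] {"output"});
        // 3) Copy the resulting tensor back into a Java array.
        inferenceInterface.fetch("output", outputs);
        return outputs;
    }
}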
<?xml version="1.0" encoding="UTF-8"?> <!-- Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.demo">
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<uses-feature android:name="android.hardware.camera.autofocus" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-sdk android:minSdkVersion="21" android:targetSdkVersion="23" />
<application android:allowBackup="true" android:debuggable="true" android:label="@string/app_name" android:icon="@drawable/ic_launcher" android:theme="@style/MaterialTheme">
<activity android:name="org.tensorflow.demo.ClassifierActivity" android:screenOrientation="portrait" android:label="@string/activity_name_classification"> <intent-filter> <action android:name="android.intent.action.MAIN" /> <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity>
<activity android:name="org.tensorflow.demo.DetectorActivity" android:screenOrientation="portrait" android:label="@string/activity_name_detection"> <intent-filter> <action android:name="android.intent.action.MAIN" /> <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity>
<activity android:name="org.tensorflow.demo.StylizeActivity" android:screenOrientation="portrait" android:label="@string/activity_name_stylize"> <intent-filter> <action android:name="android.intent.action.MAIN" /> <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity>
<activity android:name="org.tensorflow.demo.SpeechActivity" android:screenOrientation="portrait" android:label="@string/activity_name_speech"> <intent-filter> <action android:name="android.intent.action.MAIN" /> <category android:name="android.intent.category.LAUNCHER" /> </intent-filter> </activity> </application>
</manifest>
Notes on the manifest:
- It asks permission to use the camera as well as write storage; one of the activities does audio recognition, so it also asks for RECORD_AUDIO.
- There are 4 activities defined as launchers, so 4 app icons associated with them get placed on our device.
- ClassifierActivity simply labels the objects found (identity only).
- DetectorActivity both identifies and locates the objects found.
- StylizeActivity produces special effects on the image.
- SpeechActivity performs speech processing.
The build.gradle file is relatively complex -- you can read it on your own -- but note that because we are running with nativeBuildSystem = 'none', the following dependency block inside it takes effect, which automatically downloads the latest stable version of TensorFlow as an AAR and installs it in your project.
dependencies {
    if (nativeBuildSystem == 'cmake' || nativeBuildSystem == 'none') {
        compile 'org.tensorflow:tensorflow-android:+'
    }
}
res/layout/activity_camera.xml
<?xml version="1.0" encoding="utf-8"?> <!-- Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:id="@+id/container"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:background="#000"
    tools:context="org.tensorflow.demo.CameraActivity" />
java (org.tensorflow.demo) / CameraActivity.java
/* * Copyright 2016 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */
package org.tensorflow.demo;
import android.Manifest; import android.app.Activity; import android.app.Fragment; import android.content.Context; import android.content.pm.PackageManager; import android.hardware.Camera; import android.hardware.camera2.CameraAccessException; import android.hardware.camera2.CameraCharacteristics; import android.hardware.camera2.CameraManager; import android.hardware.camera2.params.StreamConfigurationMap; import android.media.Image; import android.media.Image.Plane; import android.media.ImageReader; import android.media.ImageReader.OnImageAvailableListener; import android.os.Build; import android.os.Bundle; import android.os.Handler; import android.os.HandlerThread; import android.os.Trace; import android.util.Size; import android.view.KeyEvent; import android.view.Surface; import android.view.WindowManager; import android.widget.Toast; import java.nio.ByteBuffer; import org.tensorflow.demo.env.ImageUtils; import org.tensorflow.demo.env.Logger; import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
// declares the class as implementing Android's Camera.PreviewCallback to get new frames
// does not use OpenCV
public abstract class CameraActivity extends Activity implements OnImageAvailableListener, Camera.PreviewCallback { private static final Logger LOGGER = new Logger();
private static final int PERMISSIONS_REQUEST = 1;
// setup for permissions to use camera and storage
private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; private static final String PERMISSION_STORAGE = Manifest.permission.WRITE_EXTERNAL_STORAGE;
private boolean debug = false;
private Handler handler; private HandlerThread handlerThread; private boolean useCamera2API; private boolean isProcessingFrame = false; private byte[][] yuvBytes = new byte[3][]; private int[] rgbBytes = null; private int yRowStride;
protected int previewWidth = 0; protected int previewHeight = 0;
private Runnable postInferenceCallback; private Runnable imageConverter;
@Override protected void onCreate(final Bundle savedInstanceState) { LOGGER.d("onCreate " + this); super.onCreate(null); getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
// loads layout activity_camera.xml, which contains only a FrameLayout widget where we will display the camera image and any detection results
setContentView(R.layout.activity_camera);
// checks whether the app has the permissions it requires; otherwise requests them,
// then calls the setFragment() method below
if (hasPermission()) {
  setFragment(); // this creates an instance of the CameraConnectionFragment helper class (see source) that handles the camera and preview,
                 // as well as setup of the classifier, and will display any classification results
} else {
  requestPermission();
}
}
private byte[] lastPreviewFrame;
protected int[] getRgbBytes() { imageConverter.run(); return rgbBytes; }
protected int getLuminanceStride() { return yRowStride; }
protected byte[] getLuminance() { return yuvBytes[0]; }
/** * Callback for the android.hardware.Camera API -- once we receive a callback that the camera is ready to capture images, * set up our Classifier to be able to process the images as they come in */ @Override public void onPreviewFrame(final byte[] bytes, final Camera camera) { if (isProcessingFrame) { LOGGER.w("Dropping frame!"); return; }
try { // Initialize the storage bitmaps once when the resolution is known. if (rgbBytes == null) { Camera.Size previewSize = camera.getParameters().getPreviewSize(); previewHeight = previewSize.height; previewWidth = previewSize.width; rgbBytes = new int[previewWidth * previewHeight];
// calls the abstracted method onPreviewSizeChosen() in the appropriate child class // like DetectorActivity that uses the helper class Classifier below to setup the appropriate // type of pretrained Classifier. // the connection to the camera and preview of image is handled by an instance of CameraConnectionFragment // class found in the source code (see above) onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); // method will setup classifier } } catch (final Exception e) { LOGGER.e(e, "Exception!"); return; }
isProcessingFrame = true; lastPreviewFrame = bytes; yuvBytes[0] = bytes; yRowStride = previewWidth;
imageConverter = new Runnable() { @Override public void run() { ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); } };
postInferenceCallback = new Runnable() { @Override public void run() { camera.addCallbackBuffer(bytes); isProcessingFrame = false; } }; processImage(); //this code will process the image using the classifier setup inside the extending class like DetectorActivity }
/** */ @Override public void onImageAvailable(final ImageReader reader) { //We need wait until we have some size from onPreviewSizeChosen if (previewWidth == 0 || previewHeight == 0) { return; } if (rgbBytes == null) { rgbBytes = new int[previewWidth * previewHeight]; } try { final Image image = reader.acquireLatestImage();
if (image == null) { return; }
if (isProcessingFrame) { image.close(); return; } isProcessingFrame = true; Trace.beginSection("imageAvailable"); final Plane[] planes = image.getPlanes(); fillBytes(planes, yuvBytes); yRowStride = planes[0].getRowStride(); final int uvRowStride = planes[1].getRowStride(); final int uvPixelStride = planes[1].getPixelStride();
imageConverter = new Runnable() { @Override public void run() { ImageUtils.convertYUV420ToARGB8888( yuvBytes[0], yuvBytes[1], yuvBytes[2], previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, rgbBytes); } };
postInferenceCallback = new Runnable() { @Override public void run() { image.close(); isProcessingFrame = false; } };
processImage(); //process the image using the classifier setup inside the extending class like DetectorActivity } catch (final Exception e) { LOGGER.e(e, "Exception!"); Trace.endSection(); return; } Trace.endSection(); }
@Override public synchronized void onStart() { LOGGER.d("onStart " + this); super.onStart(); }
@Override public synchronized void onResume() { LOGGER.d("onResume " + this); super.onResume();
handlerThread = new HandlerThread("inference"); handlerThread.start(); handler = new Handler(handlerThread.getLooper()); }
@Override public synchronized void onPause() { LOGGER.d("onPause " + this);
if (!isFinishing()) { LOGGER.d("Requesting finish"); finish(); }
handlerThread.quitSafely(); try { handlerThread.join(); handlerThread = null; handler = null; } catch (final InterruptedException e) { LOGGER.e(e, "Exception!"); }
super.onPause(); }
@Override public synchronized void onStop() { LOGGER.d("onStop " + this); super.onStop(); }
@Override public synchronized void onDestroy() { LOGGER.d("onDestroy " + this); super.onDestroy(); }
protected synchronized void runInBackground(final Runnable r) { if (handler != null) { handler.post(r); } }
@Override public void onRequestPermissionsResult( final int requestCode, final String[] permissions, final int[] grantResults) { if (requestCode == PERMISSIONS_REQUEST) { if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED && grantResults[1] == PackageManager.PERMISSION_GRANTED) { setFragment(); } else { requestPermission(); } } }
private boolean hasPermission() { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED && checkSelfPermission(PERMISSION_STORAGE) == PackageManager.PERMISSION_GRANTED; } else { return true; } }
private void requestPermission() { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA) || shouldShowRequestPermissionRationale(PERMISSION_STORAGE)) { Toast.makeText(CameraActivity.this, "Camera AND storage permission are required for this demo", Toast.LENGTH_LONG).show(); } requestPermissions(new String[] {PERMISSION_CAMERA, PERMISSION_STORAGE}, PERMISSIONS_REQUEST); } }
// Returns true if the device supports the required hardware level, or better. private boolean isHardwareLevelSupported( CameraCharacteristics characteristics, int requiredLevel) { int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { return requiredLevel == deviceLevel; } // deviceLevel is not LEGACY, can use numerical sort return requiredLevel <= deviceLevel; }
private String chooseCamera() { final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); try { for (final String cameraId : manager.getCameraIdList()) { final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
// We don't use a front facing camera in this sample. final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { continue; }
final StreamConfigurationMap map = characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
if (map == null) { continue; }
useCamera2API = isHardwareLevelSupported(characteristics, CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); LOGGER.i("Camera API lv2?: %s", useCamera2API); return cameraId; } } catch (CameraAccessException e) { LOGGER.e(e, "Not allowed to access camera"); }
return null; }
// this method opens up connection to camera and then when connected sets up size of frame that will be displayed
//uses helper class CameraConnectionFragment that handles connection to the camera, and preview of the camera capture
protected void setFragment() { String cameraId = chooseCamera();
Fragment fragment; if (useCamera2API) { CameraConnectionFragment camera2Fragment = // helper class (see source directory) for handling connection to camera,etc. CameraConnectionFragment.newInstance( new CameraConnectionFragment.ConnectionCallback() { @Override public void onPreviewSizeChosen(final Size size, final int rotation) { //override so only setting up sizing here previewHeight = size.getHeight(); previewWidth = size.getWidth(); CameraActivity.this.onPreviewSizeChosen(size, rotation); } }, this, getLayoutId(), getDesiredPreviewFrameSize());
camera2Fragment.setCamera(cameraId); //setup camera associated with a helper class CameraConnectionFragment // instance that handles camera connection, disconnection ,etc. fragment = camera2Fragment; } else { fragment = new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); }
getFragmentManager() // display the fragment instance here, which is an instance of CameraConnectionFragment (or LegacyCameraConnectionFragment)
    .beginTransaction()
    .replace(R.id.container, fragment)
    .commit(); }
protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { // Because of the variable row stride it's not possible to know in // advance the actual necessary dimensions of the yuv planes. for (int i = 0; i < planes.length; ++i) { final ByteBuffer buffer = planes[i].getBuffer(); if (yuvBytes[i] == null) { LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); yuvBytes[i] = new byte[buffer.capacity()]; } buffer.get(yuvBytes[i]); } }
public boolean isDebug() { return debug; }
public void requestRender() { final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); if (overlay != null) { overlay.postInvalidate(); } }
public void addCallback(final OverlayView.DrawCallback callback) { final OverlayView overlay = (OverlayView) findViewById(R.id.debug_overlay); if (overlay != null) { overlay.addCallback(callback); } }
public void onSetDebug(final boolean debug) {}
@Override public boolean onKeyDown(final int keyCode, final KeyEvent event) { if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP) { debug = !debug; requestRender(); onSetDebug(debug); return true; } return super.onKeyDown(keyCode, event); }
protected void readyForNextImage() { if (postInferenceCallback != null) { postInferenceCallback.run(); } }
protected int getScreenOrientation() { switch (getWindowManager().getDefaultDisplay().getRotation()) { case Surface.ROTATION_270: return 270; case Surface.ROTATION_180: return 180; case Surface.ROTATION_90: return 90; default: return 0; } }
protected abstract void processImage();
protected abstract void onPreviewSizeChosen(final Size size, final int rotation); protected abstract int getLayoutId(); protected abstract Size getDesiredPreviewFrameSize(); }
The YOLO (You Only Look Once) and the original Multibox detectors remain available by modifying DetectorActivity.java.
java (org.tensorflow.demo) / DetectorActivity.java -- its purpose is to perform both identification and localization using a pretrained model
package org.tensorflow.demo;
import android.graphics.Bitmap; import android.graphics.Bitmap.Config; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Matrix; import android.graphics.Paint; import android.graphics.Paint.Style; import android.graphics.RectF; import android.graphics.Typeface; import android.media.ImageReader.OnImageAvailableListener; import android.os.SystemClock; import android.util.Size; import android.util.TypedValue; import android.view.Display; import android.view.Surface; import android.widget.Toast; import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Vector; import org.tensorflow.demo.OverlayView.DrawCallback; import org.tensorflow.demo.env.BorderedText; import org.tensorflow.demo.env.ImageUtils; import org.tensorflow.demo.env.Logger; import org.tensorflow.demo.tracking.MultiBoxTracker; import org.tensorflow.demo.R; // Explicit import needed for internal Google builds.
/** * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track * objects. */ public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { private static final Logger LOGGER = new Logger();
/** * There are MULTIPLE OPTIONS available for detection
*
* 1) DEFAULT - Object Detection API code TensorFlowObjectDetectionAPIModel
* 2) Multibox - TensorFlowMultiBoxDetector * 3) YOLO - TensorFlowYoloDetector
* see src for details
*/
// Configuration values for the prepackaged multibox model. private static final int MB_INPUT_SIZE = 224; private static final int MB_IMAGE_MEAN = 128; private static final float MB_IMAGE_STD = 128; private static final String MB_INPUT_NAME = "ResizeBilinear"; private static final String MB_OUTPUT_LOCATIONS_NAME = "output_locations/Reshape"; private static final String MB_OUTPUT_SCORES_NAME = "output_scores/Reshape"; private static final String MB_MODEL_FILE = "file:///android_asset/multibox_model.pb"; private static final String MB_LOCATION_FILE = "file:///android_asset/multibox_location_priors.txt";
// Configuration values for the prepackaged Object Detection API model
private static final int TF_OD_API_INPUT_SIZE = 300; private static final String TF_OD_API_MODEL_FILE = "file:///android_asset/ssd_mobilenet_v1_android_export.pb"; private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
SPECIAL NOTE: See the paper https://www.cs.unc.edu/~wliu/papers/ssd.pdf to understand what an SSD-based CNN is -- it is a way of carving the image up into candidate regions (default boxes) that are processed through the CNN for object identification, and the region that matches tells us the location.
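To make the "region tells us the location" idea concrete, here is a small sketch of how one SSD detection is turned into a pixel-space box. The array layout ([ymin, xmin, ymax, xmax] per detection, normalized to 0..1 of the square network input) mirrors the fetch-and-scale logic in TensorFlowObjectDetectionAPIModel.recognizeImage() shown later; the helper class and method names are just for illustration.

import android.graphics.RectF;

// Illustrative helper (not part of the demo source): convert one normalized SSD
// detection into a pixel-space rectangle, matching the layout used in recognizeImage().
final class SsdBoxUtil {
    // outputLocations holds 4 floats per detection in [ymin, xmin, ymax, xmax] order,
    // each normalized to 0..1 relative to the (square) network input of size inputSize.
    static RectF boxForDetection(final float[] outputLocations, final int i, final int inputSize) {
        return new RectF(
            outputLocations[4 * i + 1] * inputSize,   // left   (xmin)
            outputLocations[4 * i + 0] * inputSize,   // top    (ymin)
            outputLocations[4 * i + 3] * inputSize,   // right  (xmax)
            outputLocations[4 * i + 2] * inputSize);  // bottom (ymax)
    }
}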
// Configuration values for tiny-yolo-voc. Note that the graph is not included with TensorFlow and // must be manually placed in the assets/ directory by the user. // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via // DarkFlow (https://github.com/thtrieu/darkflow). Sample command: // ./flow --model cfg/tiny-yolo-voc.cfg --load bin/tiny-yolo-voc.weights --savepb --verbalise private static final String YOLO_MODEL_FILE = "file:///android_asset/graph-tiny-yolo-voc.pb"; private static final int YOLO_INPUT_SIZE = 416; private static final String YOLO_INPUT_NAME = "input"; private static final String YOLO_OUTPUT_NAMES = "output"; private static final int YOLO_BLOCK_SIZE = 32;
// Which detection model to use: by default uses Tensorflow Object Detection API frozen checkpoints.
// Optionally use legacy Multibox (trained using an older version of the API) or YOLO.
private enum DetectorMode { TF_OD_API, MULTIBOX, YOLO; }
private static final DetectorMode MODE = DetectorMode.TF_OD_API; // declaring use of the Object Detection API
// Minimum detection confidence needed to track a detection -- a detection is only tracked/displayed if its
// confidence is greater than these minimum levels.
private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f; private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f; private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;
private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); //going to go for 640x480 images
private static final boolean SAVE_PREVIEW_BITMAP = false; private static final float TEXT_SIZE_DIP = 10;
private Integer sensorOrientation;
private Classifier detector;
private long lastProcessingTimeMs; private Bitmap rgbFrameBitmap = null; private Bitmap croppedBitmap = null; private Bitmap cropCopyBitmap = null;
private boolean computingDetection = false;
private long timestamp = 0;
private Matrix frameToCropTransform; private Matrix cropToFrameTransform;
private MultiBoxTracker tracker;
private byte[] luminanceCopy;
private BorderedText borderedText;
//this is the method that sets up the preview AND sets up the Classifier object detector @Override public void onPreviewSizeChosen(final Size size, final int rotation) { final float textSizePx = TypedValue.applyDimension( TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(this);
int cropSize = TF_OD_API_INPUT_SIZE;
//if running in YOLO - setup corresponding classifier if (MODE == DetectorMode.YOLO) { detector = TensorFlowYoloDetector.create( getAssets(), YOLO_MODEL_FILE, YOLO_INPUT_SIZE, YOLO_INPUT_NAME, YOLO_OUTPUT_NAMES, YOLO_BLOCK_SIZE); cropSize = YOLO_INPUT_SIZE; } else if (MODE == DetectorMode.MULTIBOX) { //if running Multibox detector = TensorFlowMultiBoxDetector.create( getAssets(), MB_MODEL_FILE, MB_LOCATION_FILE, MB_IMAGE_MEAN, MB_IMAGE_STD, MB_INPUT_NAME, MB_OUTPUT_LOCATIONS_NAME, MB_OUTPUT_SCORES_NAME); cropSize = MB_INPUT_SIZE; } else { //this is running (see above) in Object Detection API mode for classifier try {
// create a TensorFlowObjectDetectionAPIModel classifier --
// the trained model = TF_OD_API_MODEL_FILE (remember it has to be packaged/compressed for mobile use)
// the labels of the outputs = TF_OD_API_LABELS_FILE -- these are the objects it is going to detect, like dog, cat
// the input size = TF_OD_API_INPUT_SIZE (set above to 300)
detector = TensorFlowObjectDetectionAPIModel.create( getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); Toast toast = Toast.makeText( getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); toast.show(); finish(); } }
//DONE setting up classifier --now move on to rest of this method setting up preview image size for // display and processing
previewWidth = size.getWidth(); previewHeight = size.getHeight();
sensorOrientation = rotation - getScreenOrientation(); LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix( previewWidth, previewHeight, cropSize, cropSize, sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix(); frameToCropTransform.invert(cropToFrameTransform);
trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay); trackingOverlay.addCallback( new DrawCallback() { @Override public void drawCallback(final Canvas canvas) { tracker.draw(canvas); if (isDebug()) { tracker.drawDebug(canvas); } } });
addCallback( new DrawCallback() { @Override public void drawCallback(final Canvas canvas) { if (!isDebug()) { return; } final Bitmap copy = cropCopyBitmap; if (copy == null) { return; }
final int backgroundColor = Color.argb(100, 0, 0, 0); canvas.drawColor(backgroundColor);
final Matrix matrix = new Matrix(); final float scaleFactor = 2; matrix.postScale(scaleFactor, scaleFactor); matrix.postTranslate( canvas.getWidth() - copy.getWidth() * scaleFactor, canvas.getHeight() - copy.getHeight() * scaleFactor); canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<String>(); if (detector != null) { final String statString = detector.getStatString(); final String[] statLines = statString.split("\n"); for (final String line : statLines) { lines.add(line); } } lines.add("");
lines.add("Frame: " + previewWidth + "x" + previewHeight); lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); lines.add("Rotation: " + sensorOrientation); lines.add("Inference time: " + lastProcessingTimeMs + "ms");
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines); } }); }
OverlayView trackingOverlay;
// This is the METHOD that will be called to process an image
@Override protected void processImage() { ++timestamp; final long currTimestamp = timestamp; byte[] originalLuminance = getLuminance(); tracker.onFrame( previewWidth, previewHeight, getLuminanceStride(), sensorOrientation, originalLuminance, timestamp); trackingOverlay.postInvalidate();
// No mutex needed as this method is not reentrant. if (computingDetection) { readyForNextImage(); return; } computingDetection = true; LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
//grab the pixels of image frame camera currently captured into rgbFrameBitmap rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
if (luminanceCopy == null) { luminanceCopy = new byte[originalLuminance.length]; } System.arraycopy(originalLuminance, 0, luminanceCopy, 0, originalLuminance.length); readyForNextImage();
final Canvas canvas = new Canvas(croppedBitmap); canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); //put rgbFrameBitmap pixels into canvas of croppedBitmap // For examining the actual TF input. if (SAVE_PREVIEW_BITMAP) { ImageUtils.saveBitmap(croppedBitmap); }
runInBackground( new Runnable() { @Override public void run() { LOGGER.i("Running detection on image " + currTimestamp); final long startTime = SystemClock.uptimeMillis(); final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); //detect on croppedBitmap lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); final Canvas canvas = new Canvas(cropCopyBitmap); final Paint paint = new Paint(); paint.setColor(Color.RED); paint.setStyle(Style.STROKE); paint.setStrokeWidth(2.0f);
float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; switch (MODE) { case TF_OD_API: minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; break; case MULTIBOX: minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX; break; case YOLO: minimumConfidence = MINIMUM_CONFIDENCE_YOLO; break; }
final List<Classifier.Recognition> mappedRecognitions = new LinkedList<Classifier.Recognition>();
//cycle through recognition results and display on drawing canvas a box for location and info about identity for (final Classifier.Recognition result : results) { final RectF location = result.getLocation(); if (location != null && result.getConfidence() >= minimumConfidence) { canvas.drawRect(location, paint);
cropToFrameTransform.mapRect(location); result.setLocation(location); mappedRecognitions.add(result); } }
tracker.trackResults(mappedRecognitions, luminanceCopy, currTimestamp); trackingOverlay.postInvalidate();
requestRender(); computingDetection = false; } }); }
@Override protected int getLayoutId() { return R.layout.camera_connection_fragment_tracking; }
@Override protected Size getDesiredPreviewFrameSize() { return DESIRED_PREVIEW_SIZE; }
@Override public void onSetDebug(final boolean debug) { detector.enableStatLogging(debug); } }
TensorFlowObjectDetectionAPIModel -- used to create an instance of the Object Detection API classifier (which is an SSD CNN model). Notice that the method recognizeImage() does the work.
package org.tensorflow.demo;
import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.RectF; import android.os.Trace; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.PriorityQueue; import java.util.Vector; import org.tensorflow.Graph; import org.tensorflow.Operation; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; import org.tensorflow.demo.env.Logger;
/** * Wrapper for frozen detection models trained using the Tensorflow Object Detection API: * github.com/tensorflow/models/tree/master/research/object_detection */ public class TensorFlowObjectDetectionAPIModel implements Classifier { private static final Logger LOGGER = new Logger();
// Only return this many results. private static final int MAX_RESULTS = 100;
// Config values. private String inputName; private int inputSize;
// Pre-allocated buffers. private Vector<String> labels = new Vector<String>(); private int[] intValues; private byte[] byteValues; private float[] outputLocations; private float[] outputScores; private float[] outputClasses; private float[] outputNumDetections; private String[] outputNames;
private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
/** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. */ public static Classifier create( final AssetManager assetManager, final String modelFilename, final String labelFilename, final int inputSize) throws IOException { final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();
InputStream labelsInput = null; String actualFilename = labelFilename.split("file:///android_asset/")[1]; labelsInput = assetManager.open(actualFilename); BufferedReader br = null; br = new BufferedReader(new InputStreamReader(labelsInput)); String line; while ((line = br.readLine()) != null) { LOGGER.w(line); d.labels.add(line); } br.close();
d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = d.inferenceInterface.graph();
d.inputName = "image_tensor"; // The inputName node has a shape of [N, H, W, C], where // N is the batch size // H = W are the height and width // C is the number of channels (3 for our purposes - RGB) final Operation inputOp = g.operation(d.inputName); if (inputOp == null) { throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); } d.inputSize = inputSize; // The outputScoresName node has a shape of [N, NumLocations], where N // is the batch size. final Operation outputOp1 = g.operation("detection_scores"); if (outputOp1 == null) { throw new RuntimeException("Failed to find output Node 'detection_scores'"); } final Operation outputOp2 = g.operation("detection_boxes"); if (outputOp2 == null) { throw new RuntimeException("Failed to find output Node 'detection_boxes'"); } final Operation outputOp3 = g.operation("detection_classes"); if (outputOp3 == null) { throw new RuntimeException("Failed to find output Node 'detection_classes'"); }
// Pre-allocate buffers. d.outputNames = new String[] {"detection_boxes", "detection_scores", "detection_classes", "num_detections"}; d.intValues = new int[d.inputSize * d.inputSize]; d.byteValues = new byte[d.inputSize * d.inputSize * 3]; d.outputScores = new float[MAX_RESULTS]; d.outputLocations = new float[MAX_RESULTS * 4]; d.outputClasses = new float[MAX_RESULTS]; d.outputNumDetections = new float[1]; return d; }
private TensorFlowObjectDetectionAPIModel() {}
@Override public List<Recognition> recognizeImage(final Bitmap bitmap) { // Log this method so that it can be analyzed with systrace. Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap"); // Preprocess the image data from 0-255 int to normalized float based // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) { byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF); byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF); byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); } Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow. Trace.beginSection("feed"); inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3); Trace.endSection();
// Run the inference call. Trace.beginSection("run"); inferenceInterface.run(outputNames, logStats); Trace.endSection();
// Copy the output Tensor back into the output array. Trace.beginSection("fetch"); outputLocations = new float[MAX_RESULTS * 4]; outputScores = new float[MAX_RESULTS]; outputClasses = new float[MAX_RESULTS]; outputNumDetections = new float[1]; inferenceInterface.fetch(outputNames[0], outputLocations); inferenceInterface.fetch(outputNames[1], outputScores); inferenceInterface.fetch(outputNames[2], outputClasses); inferenceInterface.fetch(outputNames[3], outputNumDetections); Trace.endSection();
// Find the best detections. final PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>( 1, new Comparator<Recognition>() { @Override public int compare(final Recognition lhs, final Recognition rhs) { // Intentionally reversed to put high confidence at the head of the queue. return Float.compare(rhs.getConfidence(), lhs.getConfidence()); } });
// Scale them back to the input size. for (int i = 0; i < outputScores.length; ++i) { final RectF detection = new RectF( outputLocations[4 * i + 1] * inputSize, outputLocations[4 * i] * inputSize, outputLocations[4 * i + 3] * inputSize, outputLocations[4 * i + 2] * inputSize); pq.add( new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection)); }
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { //pq is priority queue in order of confidence recognitions.add(pq.poll()); } Trace.endSection(); // "recognizeImage" return recognitions; }
@Override public void enableStatLogging(final boolean logStats) { this.logStats = logStats; }
@Override public String getStatString() { return inferenceInterface.getStatString(); }
@Override public void close() { inferenceInterface.close(); } }
java (org.tensorflow.demo) / Classifier.java, a "helper" interface that represents a generic TensorFlow classifier -- implementations of it load a pretrained model and return Recognition results.
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/
package org.tensorflow.demo;
import android.graphics.Bitmap; import android.graphics.RectF; import java.util.List;
/** * Generic interface for interacting with different recognition engines. */ public interface Classifier { /** * An immutable result returned by a Classifier describing what was recognized. */ public class Recognition { /** * A unique identifier for what has been recognized. Specific to the class, not the instance of * the object. */ private final String id;
/** * Display name for the recognition. */ private final String title;
/** * A sortable score for how good the recognition is relative to others. Higher should be better. */ private final Float confidence;
/** Optional location within the source image for the location of the recognized object. */ private RectF location;
public Recognition( final String id, final String title, final Float confidence, final RectF location) { this.id = id; this.title = title; this.confidence = confidence; this.location = location; }
public String getId() { return id; }
public String getTitle() { return title; }
public Float getConfidence() { return confidence; }
public RectF getLocation() { return new RectF(location); }
public void setLocation(RectF location) { this.location = location; }
@Override public String toString() { String resultString = ""; if (id != null) { resultString += "[" + id + "] "; }
if (title != null) { resultString += title + " "; }
if (confidence != null) { resultString += String.format("(%.1f%%) ", confidence * 100.0f); }
if (location != null) { resultString += location + " "; }
return resultString.trim(); } }
List<Recognition> recognizeImage(Bitmap bitmap);
void enableStatLogging(final boolean debug); String getStatString();
void close(); }
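As a usage illustration, any recognition engine only needs to provide these few members. The stub below is hypothetical (not part of the demo) and just shows the minimal shape of an implementation; real implementations such as TensorFlowObjectDetectionAPIModel do the actual model work.

import android.graphics.Bitmap;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stub -- shows the minimal shape of a Classifier implementation.
public class DummyClassifier implements Classifier {
    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // A real implementation would run the model here; this stub reports no detections.
        return new ArrayList<Recognition>();
    }

    @Override
    public void enableStatLogging(final boolean debug) {
        // No stats to log in this stub.
    }

    @Override
    public String getStatString() {
        return "";
    }

    @Override
    public void close() {
        // Nothing to release in this stub.
    }
}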