Tutorial: Classifier
How to create an Iris classifier on Android using Eclipse Deeplearning4j.
The example application trains a small neural network on the device, using Anderson’s Iris data set to classify iris flowers by type. For a more in-depth look at optimizing Android for DL4J, see the Prerequisites and Configuration documentation. The application has a simple UI that takes measurements of petal length, petal width, sepal length, and sepal width from the user and returns the probability that the measurements belong to each of three types of iris (Iris setosa, Iris versicolor, and Iris virginica). The data set contains 150 samples (50 for each iris type), and training the model takes roughly 5 to 20 seconds, depending on the device.
Setup
For detailed setup instructions see the Deeplearning4J Android Setup section.
Setting up the neural network on a background thread
Training even a simple neural network like the one in this example requires a significant amount of processor power, which is in limited supply on mobile devices. It is therefore imperative to build and train the neural network on a background thread, which then returns the output to the main thread for updating the UI. In this example we will use an AsyncTask that accepts the input measurements from the UI and passes them as type Double to the doInBackground() method. First, let’s get references to the EditText fields in the UI layout that accept the iris measurements, inside our onCreate() method. Then an OnClickListener will execute our AsyncTask, pass it the measurements entered by the user, and show a progress bar until we hide it again in onPostExecute().
Now let’s write our AsyncTask<Params, Progress, Result>. The AsyncTask needs a Params type of Double to receive the decimal measurements from the UI. The Result type is set to INDArray, which is returned from the doInBackground() method and passed to the onPostExecute() method for updating the UI. NDArrays are provided by the ND4J library and are essentially n-dimensional arrays. For more on NDArrays, see https://nd4j.org/userguide.
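The background-thread pattern described above can be sketched in plain Java without any Android classes. This is only an illustration under stated assumptions: BackgroundWork, runInBackground, and the "trainer" thread name are hypothetical, and a java.util.concurrent executor stands in for AsyncTask (which was deprecated in Android API level 30). The placeholder task returns fixed values where a real app would build, train, and query the network.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

// Hypothetical helper: plain-Java stand-in for the AsyncTask described in the text.
class BackgroundWork {
    // Single worker thread playing the role of AsyncTask's background thread.
    private static final ExecutorService WORKER = Executors.newSingleThreadExecutor(r -> {
        Thread t = new Thread(r, "trainer");
        t.setDaemon(true); // do not keep the JVM alive once main() returns
        return t;
    });

    /** Runs the task on the worker thread and returns a future holding its result. */
    static CompletableFuture<double[]> runInBackground(double[] measurements,
                                                       Function<double[], double[]> task) {
        return CompletableFuture.supplyAsync(() -> task.apply(measurements), WORKER);
    }

    public static void main(String[] args) {
        double[] input = {5.1, 3.5, 1.4, 0.2};
        // Placeholder task: a real app would train the network and classify the input here.
        CompletableFuture<double[]> result =
                runInBackground(input, m -> new double[] {1.0, 0.0, 0.0});
        // On Android, the result would be handed back to the main thread before touching views.
        System.out.println("classes returned: " + result.join().length);
    }
}
```

The caller never blocks while the task runs; only join() waits, which on Android would be replaced by posting the result to the main thread.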
Preparing the training data set and user input
The doInBackground() method handles the formatting of the training data, the construction of the neural net, the training of the net, and the analysis of the input data by the trained model. The user input has only 4 values, so we can add those directly to a 1x4 INDArray using the putScalar() method. The training data is much larger and must be converted from comma-separated lists to matrices with a for loop.
The training data is stored in the app as two arrays: irisData, which holds the measurements for the 150 irises, and labelData, which holds the iris type labels. These are transformed into 150x4 and 150x3 matrices, respectively, so that they can be converted into INDArray objects that the neural network will use for training.
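As a rough illustration of the user-input step, the sketch below parses four text values into a 1x4 row using plain Java arrays. UserInput and toRow are hypothetical names, and a double[1][4] stands in for the 1x4 INDArray; with ND4J, each value would instead be written into the INDArray with putScalar(). The order of the four fields is assumed to match the column order of the training data.

```java
// Hypothetical helper: builds the 1x4 input row from the four UI text fields.
class UserInput {
    /** Parses exactly four measurement strings into a single 1x4 row. */
    static double[][] toRow(String... fields) {
        if (fields.length != 4) {
            throw new IllegalArgumentException("expected 4 measurements, got " + fields.length);
        }
        double[][] row = new double[1][4];
        for (int c = 0; c < 4; c++) {
            // trim() tolerates stray whitespace typed into a text field
            row[0][c] = Double.parseDouble(fields[c].trim());
        }
        return row;
    }

    public static void main(String[] args) {
        double[][] row = toRow("5.1", "3.5", "1.4", "0.2");
        System.out.println("rows=" + row.length + ", cols=" + row[0].length);
    }
}
```

Double.parseDouble throws NumberFormatException on empty or non-numeric text, so a real UI would validate the fields before starting the background work.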
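The transformation into matrices can be sketched as a nested for loop over a flat array. This assumes that the data is stored row by row (all four measurements of sample 0, then sample 1, and so on) and that the labels are already one-hot encoded; TrainingData and reshape are hypothetical names, and the two-sample arrays in main() are stand-ins for the real 150-sample arrays.

```java
// Hypothetical helper: turns flat, row-major arrays into rows x cols matrices.
class TrainingData {
    /** Copies a flat, row-major array into a rows x cols matrix. */
    static double[][] reshape(double[] flat, int rows, int cols) {
        if (flat.length != rows * cols) {
            throw new IllegalArgumentException(
                    "expected " + (rows * cols) + " values, got " + flat.length);
        }
        double[][] matrix = new double[rows][cols];
        int i = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                matrix[r][c] = flat[i++];
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        // Two illustrative samples; the real arrays would hold 150 samples.
        double[] measurements = {5.1, 3.5, 1.4, 0.2, 7.0, 3.2, 4.7, 1.4};
        double[] labels = {1, 0, 0, 0, 1, 0}; // one-hot rows, one column per iris type
        double[][] features = reshape(measurements, 2, 4);
        double[][] targets = reshape(labels, 2, 3);
        System.out.println(features.length + "x" + features[0].length
                + " features, " + targets.length + "x" + targets[0].length + " labels");
    }
}
```

For the full data set the calls would be reshape(irisData, 150, 4) and reshape(labelData, 150, 3), provided the arrays really are flat double arrays of 600 and 450 values.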
Building and Training the Neural Network
Now that our data is ready, we can build a simple multi-layer perceptron with a single hidden layer. The DenseLayer class is used to create the input layer and the hidden layer of the network, while the OutputLayer class is used for the output layer. The number of columns in the input INDArray must equal the number of inputs to the input layer (nIn). The number of inputs to the hidden layer (its nIn) must equal the number of outputs of the inputLayer (its nOut). Likewise, the nIn of the outputLayer should match the nOut of the hiddenLayer. Finally, the nOut of the outputLayer must equal the number of possible classifications, which is 3.
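The size rules above amount to a chain: each layer's nIn must equal the previous layer's nOut, starting from the number of input columns and ending at the number of classes. The following plain-Java check is a sketch of that rule only, not DL4J code. LayerSizes and chainIsValid are hypothetical names, and the hidden sizes of 3 in main() are an assumption chosen for illustration.

```java
// Hypothetical helper: checks that nIn/nOut values line up from input to output.
class LayerSizes {
    /**
     * layers[i] holds {nIn, nOut} for layer i, in network order.
     * Returns true when the first nIn equals inputColumns, every later nIn equals
     * the previous nOut, and the final nOut equals numClasses.
     */
    static boolean chainIsValid(int inputColumns, int numClasses, int[][] layers) {
        int expectedIn = inputColumns;
        for (int[] layer : layers) {
            if (layer[0] != expectedIn) {
                return false; // this layer's nIn does not match what feeds into it
            }
            expectedIn = layer[1];
        }
        return expectedIn == numClasses;
    }

    public static void main(String[] args) {
        int[][] layers = {
                {4, 3}, // input layer: 4 measurements in, 3 out (assumed size)
                {3, 3}, // hidden layer (assumed size)
                {3, 3}  // output layer: 3 iris types out
        };
        System.out.println("sizes consistent: " + chainIsValid(4, 3, layers));
    }
}
```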
The next step is to build the neural network using nccBuilder. The parameters selected below for training are standard. To learn more about optimizing network training, see deeplearning4j.org.
Updating the UI
Once the training of the neural network and the classification of the user measurements are complete, the doInBackground() method finishes and onPostExecute() runs on the main thread, where it can update the UI with the classification results. Note that the number of decimal places shown for the probabilities can be controlled by setting a DecimalFormat pattern.
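As a small sketch of the DecimalFormat step, the code below formats three probabilities with a caller-supplied pattern. ResultFormatter, format, and the label order are assumptions made for illustration; in the app the probabilities would be read from the output INDArray and the text would be written to views in onPostExecute(). The symbols are pinned to Locale.US so the decimal separator does not depend on the device locale.

```java
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Locale;

// Hypothetical helper: turns class probabilities into display text.
class ResultFormatter {
    // Label order is assumed to match the order of the label columns used in training.
    private static final String[] LABELS = {"Iris setosa", "Iris versicolor", "Iris virginica"};

    /** Formats each probability with the given DecimalFormat pattern, e.g. "0.00". */
    static String format(double[] probabilities, String pattern) {
        if (probabilities.length != LABELS.length) {
            throw new IllegalArgumentException("expected " + LABELS.length + " probabilities");
        }
        DecimalFormat df = new DecimalFormat(pattern, DecimalFormatSymbols.getInstance(Locale.US));
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < probabilities.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(LABELS[i]).append(": ").append(df.format(probabilities[i]));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(format(new double[] {0.9712, 0.0203, 0.0085}, "0.00"));
    }
}
```

Changing the pattern to "0.000" would show three decimal places instead of two.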
Conclusion
Hopefully this tutorial has illustrated how the compatibility of DL4J with Android makes it easy to build, train, and evaluate neural networks on mobile devices. We used a simple UI to take the measurement values from the user and then passed them as the Params of an AsyncTask. The processor-intensive steps of data preparation, network layer building, model training, and evaluation of the user data were all performed in the doInBackground() method on the background thread, keeping the device stable and responsive. Once these steps completed, we passed the output INDArray as the AsyncTask Result to onPostExecute(), where the UI was updated to show the classification results. The limited processing power and battery life of mobile devices make training robust, multi-layer networks on the device somewhat impractical. To address this limitation, we will next look at an example Android application that saves the trained model on the device for faster performance after the initial model training.
The complete code for this example is available here.