How to create an IRIS classifier on Android using Eclipse Deeplearning4j.
The example application trains a small neural network on the device using Anderson’s Iris data set for iris flower type classification. For a more indepth look at optimizing android for DL4J, please see the Prerequisites and Configuration documentation here. This application has a simple UI to take measurements of petal length, petal width, sepal length, and sepal width from the user and returns the probability that the measurements belong to one of three types of Iris (Iris serosa, Iris versicolor, and Iris virginica). A data set includes 150 measurement values (50 for each iris type) and training the model takes anywhere from 5-20 seconds, depending on the device.
Setting up the neural network on a background thread
Training even a simple neural network like in this example requires a significant amount of processor power, which is in limited supply on mobile devices. Thus, it is imperative that a background thread be used for the building and training of the neural network which then returns the output to the main thread for updating the UI. In this example we will be using an AsyncTask which accepts the input measurements from the UI and passes them as type double to the doInBackground() method. First, lets get references to the editTexts in the UI layout that accept the iris measurements inside of our onCreate method. Then an onClickListener will execute our asyncTask, pass it the measurements entered by the user, and show a progress bar until we hide it again in onPostExecute().
publicclassMainActivityextendsAppCompatActivity {@OverridepublicvoidonCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState);setContentView(R.layout.activity_main);//get references to the editTexts that take the measurementsfinalEditText PL = (EditText) findViewById(R.id.editText);finalEditText PW = (EditText) findViewById(R.id.editText2);finalEditText SL = (EditText) findViewById(R.id.editText3);finalEditText SW = (EditText) findViewById(R.id.editText4);//onclick to capture the input and launch the asyncTaskButton button = (Button) findViewById(R.id.button);button.setOnClickListener(new View.OnClickListener() { @OverridepublicvoidonClick(View v) {finaldouble pl =Double.parseDouble(PL.getText().toString());finaldouble pw =Double.parseDouble(PW.getText().toString());finaldouble sl =Double.parseDouble(SL.getText().toString());finaldouble sw =Double.parseDouble(SW.getText().toString());AsyncTaskRunner runner =newAsyncTaskRunner();//pass the measurement as params to the AsyncTaskrunner.execute(pl,pw,sl,sw);ProgressBar bar = (ProgressBar) findViewById(R.id.progressBar);bar.setVisibility(View.VISIBLE); } }); }
Now let’s write our AsyncTask<Params, Progress, Results>. The AsyncTask needs to have a Params of type Double to receive the decimal value measurements from the UI. The Result type is set to INDArray, which is returned from the doInBackground() Method and passed to the onPostExecute() method for updating the UI. NDArrays are provided by the ND4J library and are essentially n-dimensional arrays with a given number of dimensions. For more on NDArrays, see https://nd4j.org/userguide.
privateclassAsyncTaskRunnerextendsAsyncTask<Double,Integer,INDArray> {// Runs in UI before background thread is called @OverrideprotectedvoidonPreExecute() { super.onPreExecute();ProgressBar bar = (ProgressBar) findViewById(R.id.progressBar);bar.setVisibility(View.INVISIBLE); }
Preparing the training data set and user input
The doInBackground() method will handle the formatting of the training data, the construction of the neural net, the training of the net, and the analysis of the input data by the trained model. The user input has only 4 values, thus we can add those directly to a 1x4 INDArray using the putScalar() method. The training data is much larger and must be converted from CSV lists to matrices through an iterative for loop.
The training data is stored in the app as two arrays, one for the Iris measurements named irisData which contains a list of 150 iris measurements and another for the labels of iris type named labelData. These will be transformed to 150x4 and 150x3 matrices, respectively, so that they can be converted into INDArray objects that the neural network will use for training.
// This is our main background thread for the neural net @OverrideprotectedStringdoInBackground(Double...params) {//Get the doubles from params, which is an array so they will be 0,1,2,3double pld = params[0];double pwd = params[1];double sld = params[2];double swd = params[3];//Create input INDArray for the user measurementsINDArray actualInput =Nd4j.zeros(1,4);actualInput.putScalar(newint[]{0,0}, pld);actualInput.putScalar(newint[]{0,1}, pwd);actualInput.putScalar(newint[]{0,2}, sld);actualInput.putScalar(newint[]{0,3}, swd);//Convert the iris data into 150x4 matrixint row=150;int col=4;double[][] irisMatrix=newdouble[row][col];int i =0;for(int r=0; r<row; r++){for( int c=0; c<col; c++){ irisMatrix[r][c]=com.example.jmerwin.irisclassifier.DataSet.irisData[i++]; } }//Now do the same for the label dataint rowLabel=150;int colLabel=3;double[][] twodimLabel=newdouble[rowLabel][colLabel];int ii =0;for(int r=0; r<rowLabel; r++){for( int c=0; c<colLabel; c++){ twodimLabel[r][c]=com.example.jmerwin.irisclassifier.DataSet.labelData[ii++]; } }//Converting the data matrices into training INDArrays is straight forwardINDArray trainingIn =Nd4j.create(irisMatrix);INDArray trainingOut =Nd4j.create(twodimLabel);
Building and Training the Neural Network
Now that our data is ready, we can build a simple multi-layer perceptron with a single hidden layer. The DenseLayer class is used to create the input layer and the hidden layer of the network while the OutputLayer class is used for the Output layer. The number of columns in the input INDArray must equal to the number of neurons in the input layer (nIn). The number of neurons in the hidden layer input must equal the number inputLayer’s output array (nOut). Finally, the outputLayer input should match the hiddenLayer output. The output must equal the number of possible classifications, which is 3.
//define the layers of the networkDenseLayer inputLayer =new DenseLayer.Builder().nIn(4).nOut(3).name("Input").build();DenseLayer hiddenLayer =new DenseLayer.Builder().nIn(3).nOut(3).name("Hidden").build();OutputLayer outputLayer =new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(3).nOut(3).name("Output").activation(Activation.SOFTMAX).build();
The next step is to build the neural network using nccBuilder. The parameters selected below for training are standard. To learn more about optimizing network training, see deeplearning4j.org.
NeuralNetConfiguration.Builder nncBuilder =new NeuralNetConfiguration.Builder();long seed =6;nncBuilder.seed(seed);nncBuilder.activation(Activation.TANH);nncBuilder.weightInit(WeightInit.XAVIER);NeuralNetConfiguration.ListBuilder listBuilder =nncBuilder.list();listBuilder.layer(0, inputLayer);listBuilder.layer(1, hiddenLayer);listBuilder.layer(2, outputLayer);MultiLayerNetwork myNetwork =newMultiLayerNetwork(listBuilder.build());myNetwork.init();//Create a data set from the INDArrays and train the networkDataSet myData =newDataSet(trainingIn, trainingOut);for(int l=0; l<=1000; l++) {myNetwork.fit(myData); }//Evaluate the input data against the modelINDArray actualOutput =myNetwork.output(actualInput);Log.d("myNetwork Output ",actualOutput.toString());//Here we return the INDArray to onPostExecute where it can be //used to update the UIreturn actualOutput;}
Updating the UI
Once the training of the neural network and the classification of the user measurements are complete, the doInBackground() method will finish and onPostExecute() will have access to the main thread and UI, allowing us to update the UI with the classification results. Note that the decimal places reported on the probabilities can be controlled by setting a DecimalFormat pattern.
//This is where we update the UI with our classification results @OverrideprotectedvoidonPostExecute(INDArray result) { super.onPostExecute(result);//Hide the progress bar now that we are finishedProgressBar bar = (ProgressBar) findViewById(R.id.progressBar);bar.setVisibility(View.INVISIBLE);//Retrieve the three probabilitiesDouble first =result.getDouble(0,0);Double second =result.getDouble(0,1);Double third =result.getDouble(0,2);//Update the UI with outputTextView setosa = (TextView) findViewById(R.id.textView11);TextView versicolor = (TextView) findViewById(R.id.textView12);TextView virginica = (TextView) findViewById(R.id.textView13);//Limit the double to values to two decimals using DecimalFormatDecimalFormat df2 =newDecimalFormat(".##");//Set the text of the textViews in UI to show the probabilitessetosa.setText(String.valueOf(df2.format(first)));versicolor.setText(String.valueOf(df2.format(second)));virginica.setText(String.valueOf(df2.format(third))); }
Conclusion
Hopefully this tutorial has illustrated how the compatibility of DL4J with Android makes it easy to build, train, and evaluate neural networks on mobile devices. We used a simple UI to take input values from the measurement and then passed them as the Params in an AsyncTask. The processor intensive steps of data preparation, network layer building, model training, and evaluation of the user data were all performed in the doInBackground() method of the background thread, maintaining a stable and responsive device. Once completed, we passed the output INDArray as the AsyncTask Results to onPostExecute() where the UI was updated to demonstrate the classification results. The limitations of processing power and battery life of mobile devices make training robust, multi-layer networks somewhat unfeasible. To address this limitation, we will next look at an example Android application that saves the trained model on the device for faster performance after an initial model training.
The complete code for this example is available here.