When training neural networks, it is important to avoid overfitting the training data. Overfitting occurs when the neural network learns the noise in the training data and thus does not generalize well to data it has not been trained on. One hyperparameter that affects whether the neural network overfits is the number of epochs, that is, the number of complete passes through the training split. If we use too many epochs, the neural network is likely to overfit. On the other hand, if we use too few epochs, the neural network might not have the chance to learn fully from the training data.
Early stopping is one mechanism for choosing the number of epochs automatically, guarding against both underfitting and overfitting. The idea behind early stopping is intuitive. First the data is split into training and testing sets. At the end of each epoch, the neural network is evaluated on the test set. If the neural network outperforms the previous best model, we save a copy of it. The best overall model is then taken to be the final model.
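Conceptually, the training loop with early stopping looks like the following sketch. The train, evaluate, and save helpers here are hypothetical placeholders, not DL4J API; DL4J's own early stopping classes are introduced below.

// A minimal sketch of early stopping, assuming hypothetical train,
// evaluate, and save helpers (placeholders, not DL4J API).
var bestScore = Double.MaxValue
for (epoch <- 0 until maxEpochs) {
  train(network, trainingData)            // one complete pass over the training split
  val score = evaluate(network, testData) // e.g. the loss on the held-out set
  if (score < bestScore) {                // a lower loss means a better model
    bestScore = score
    save(network)                         // keep a copy of the best model so far
  }
}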
In this tutorial we will show how to use early stopping with deeplearning4j (DL4J). We will apply the method to a feed-forward neural network trained on the MNIST dataset of handwritten digits.
import org.apache.commons.io.FilenameUtils
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.learning.config.Nesterovs // needed for the Nesterovs updater used below
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver
import org.deeplearning4j.earlystopping.EarlyStoppingResult
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.Updater
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.DataSet
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.File
import java.util.concurrent.TimeUnit
Now that we have imported everything needed to run this tutorial, we can start by setting the parameters for the neural network and initializing the data. Later, when configuring early stopping, we will cap training at a maximum of 10 epochs.
val numRows = 28    // height of the MNIST images in pixels
val numColumns = 28 // width of the MNIST images in pixels
val outputNum = 10  // number of output classes (digits 0-9)
val batchSize = 128 // number of examples per minibatch
val rngSeed = 123   // random seed for reproducibility
val mnistTrain: DataSetIterator = new MnistDataSetIterator(batchSize, true, rngSeed)
val mnistTest: DataSetIterator = new MnistDataSetIterator(batchSize, false, rngSeed)
Next we will build the neural network configuration using DL4J's NeuralNetConfiguration.Builder and initialize a MultiLayerNetwork from the resulting MultiLayerConfiguration.
val conf: MultiLayerConfiguration = new NeuralNetConfiguration.Builder()
  .seed(rngSeed) // include a random seed for reproducibility
  // use stochastic gradient descent as the optimization algorithm
  .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  .updater(new Nesterovs(0.006)) // Nesterovs update scheme with a 0.006 learning rate
  .l2(1e-4) // L2 regularization
  .list()
  .layer(0, new DenseLayer.Builder() // create the hidden layer with Xavier initialization
    .nIn(numRows * numColumns)
    .nOut(1000)
    .activation(Activation.RELU)
    .weightInit(WeightInit.XAVIER)
    .build())
  .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) // create the output layer
    .nIn(1000)
    .nOut(outputNum)
    .activation(Activation.SOFTMAX)
    .weightInit(WeightInit.XAVIER)
    .build())
  .build()
val model: MultiLayerNetwork = new MultiLayerNetwork(conf)
If we weren’t using early stopping, we would proceed by training the neural network with a fixed-length training loop, calling the fit method of the MultiLayerNetwork once per epoch (see the sketch below). Since we are using early stopping, we instead need to configure how early stopping will be applied. Looking at the next cell, we will use a maximum of 10 epochs and a maximum training time of 5 minutes. The evaluation will be done on mnistTest at the end of each epoch. Each improved model will be saved in the DL4JEarlyStoppingExample directory that we specify below.
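For contrast, a plain fixed-epoch training loop would look roughly like the following sketch. Here numEpochs is a hypothetical hand-picked value that we would have to tune ourselves; it is not part of this tutorial's code.

// Training without early stopping: the number of epochs is fixed up front.
// numEpochs is a hypothetical hand-picked hyperparameter.
val numEpochs = 15
model.init()
model.setListeners(new ScoreIterationListener(100)) // log the score every 100 iterations
for (epoch <- 0 until numEpochs) {
  model.fit(mnistTrain) // one complete pass over the training data
}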
Once the EarlyStoppingConfiguration is specified, we only need to initialize an EarlyStoppingTrainer using the training data and the two previous configurations. The results are obtained just by calling the fit method of the EarlyStoppingTrainer.
val tempDir: String = System.getProperty("java.io.tmpdir")
val exampleDirectory: String = FilenameUtils.concat(tempDir, "DL4JEarlyStoppingExample/")
val dirFile: File = new File(exampleDirectory)
dirFile.mkdir() // create the directory the models will be saved in
val saver = new LocalFileModelSaver(exampleDirectory)
val esConf = new EarlyStoppingConfiguration.Builder()
  .epochTerminationConditions(new MaxEpochsTerminationCondition(10)) // stop after at most 10 epochs
  .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(5, TimeUnit.MINUTES)) // or after 5 minutes of training
  .scoreCalculator(new DataSetLossCalculator(mnistTest, true)) // score each model by its loss on mnistTest
  .evaluateEveryNEpochs(1) // evaluate at the end of every epoch
  .modelSaver(saver) // save models to the directory above
  .build()
val trainer = new EarlyStoppingTrainer(esConf, conf, mnistTrain)
val result = trainer.fit()
We can then print out the details of the best model.
println("Termination reason: " + result.getTerminationReason())println("Termination details: " + result.getTerminationDetails())println("Total epochs: " + result.getTotalEpochs())println("Best epoch number: " + result.getBestModelEpoch())println("Score at best epoch: " + result.getBestModelScore())