How to visualize, monitor and debug neural network learning.
Note: the information here pertains to DL4J versions 1.0.0-beta6 and later.
DL4J provides a user interface to visualize the current network status and progress of training in your browser, in real time. The UI is typically used to help with tuning neural networks - i.e., the selection of hyperparameters (such as the learning rate) to obtain good performance for a network.
Step 1: Add the Deeplearning4j UI dependency to your project.
Step 2: Enable the UI in your project
This is relatively straightforward:
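For example (a sketch using the standard UIServer / StatsListener API; net is assumed to be your MultiLayerNetwork or ComputationGraph, and package locations may differ slightly between DL4J releases):

```java
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;

// Initialize the user interface backend
UIServer uiServer = UIServer.getInstance();

// Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
StatsStorage statsStorage = new InMemoryStatsStorage();

// Attach the StatsStorage instance to the UI: this allows the contents of StatsStorage to be visualized
uiServer.attach(statsStorage);

// Then add the StatsListener to collect this information from the network, as it trains
net.setListeners(new StatsListener(statsStorage));
```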
To access the UI, open your browser and go to http://localhost:9000/train/overview. You can set the port using the org.deeplearning4j.ui.port system property: i.e., to use port 9001, pass the following to the JVM on launch: -Dorg.deeplearning4j.ui.port=9001
Information will then be collected and routed to the UI when you call the fit method on your network.
Example: see a UI example here. The full set of UI examples is available here.
The overview page (one of 3 available pages) contains the following information:
Top left: score vs iteration chart - this is the value of the loss function on the current minibatch
Top right: model and training information
Bottom left: Ratio of parameters to updates (by layer) for all network weights vs. iteration
Bottom right: Standard deviations (vs. time) of: activations, gradients and updates
Note that for the bottom two charts, the values are displayed as the logarithm (base 10) of the raw values. Thus a value of -3 on the update:parameter ratio chart corresponds to a ratio of 10^-3 = 0.001.
The ratio of updates to parameters is specifically the ratio of the mean magnitudes of these values (i.e., log10(mean(abs(updates)) / mean(abs(parameters)))).
See the later section of this page on how to use these values in practice.
The model page contains a graph of the neural network layers, which operates as a selection mechanism. Click on a layer to display information for it.
On the right, the following charts are available, after selecting a layer:
Table of layer information
Update to parameter ratio for this layer, as per the overview page. The components of this ratio (the parameter and update mean magnitudes) are also available via tabs.
Layer activations (mean and mean +/- 2 standard deviations) over time
Histograms of parameters and updates, for each parameter type
Learning rate vs. time (note this will be flat, unless learning rate schedules are used)
Note: parameters are labeled as follows: weights (W) and biases (b). For recurrent neural networks, W refers to the weights connecting the layer to the layer below, and RW refers to the recurrent weights (i.e., those between time steps).
The DL4J UI can be used with Spark. However, as of 0.7.0, conflicting dependencies mean that running the UI and Spark in the same JVM can be difficult.
Two alternatives are available:
Collect and save the relevant stats, to be visualized (offline) at a later point
Run the UI in a separate server, and use the remote UI functionality to upload the data from the Spark master to your UI instance
Collecting Stats for Later Offline Use
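For example, using the FileStatsStorage class (a sketch; sparkNet and the file name are placeholders):

```java
// Save training stats to a file for later (offline) visualization
StatsStorage statsStorage = new FileStatsStorage(new File("myNetworkTrainingStats.dl4j"));
sparkNet.setListeners(statsStorage, Collections.singletonList(new StatsListener(null)));
```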
Then, later you can load and display the saved information using:
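A sketch, where statsFile is the file saved above:

```java
StatsStorage statsStorage = new FileStatsStorage(statsFile);  // if the file already exists: load the data from it
UIServer uiServer = UIServer.getInstance();
uiServer.attach(statsStorage);
```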
Using the Remote UI Functionality
First, in the JVM running the UI (note this is the server):
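A sketch of the server-side setup:

```java
UIServer uiServer = UIServer.getInstance();
uiServer.enableRemoteListener();  // necessary: remote support is not enabled by default
```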
This will require the deeplearning4j-ui dependency. (Note: this is the server; the client instead uses deeplearning4j-ui-model - see below.)
Second, on the client side - for your neural network (both Spark networks and standalone networks using plain deeplearning4j-nn). Note this example is for Spark, but ComputationGraph and MultiLayerNetwork both have an equivalent setListeners method with the same usage (example found here):
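A sketch of the client side (UI_MACHINE_IP and the port are placeholders, as noted below):

```java
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://UI_MACHINE_IP:9000");
sparkNet.setListeners(remoteUIRouter, Collections.singletonList(new StatsListener(null)));
```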
To avoid dependency conflicts with Spark, you should use the deeplearning4j-ui-model dependency to get the StatsListener, not the full deeplearning4j-ui UI dependency.
Note: you should replace UI_MACHINE_IP with the IP address of the machine running the user interface instance.
Here's an excellent web page by Andrej Karpathy about visualizing neural net training. It is worth reading and understanding that page first.
Tuning neural networks is often more an art than a science. However, here are some ideas that may be useful:
Overview Page - Model Score vs. Iteration Chart
The score vs. iteration should (overall) go down over time.
If the score increases consistently, your learning rate is likely set too high. Try reducing it until scores become more stable.
Increasing scores can also be indicative of other network issues, such as incorrect data normalization
If the score is flat or decreases very slowly (over a few hundred iterations) (a) your learning rate may be too low, or (b) you might be having difficulties with optimization. In the latter case, if you are using the SGD updater, try a different updater such as Nesterovs (momentum), RMSProp or Adagrad.
Note that data that isn't shuffled (i.e., each minibatch contains only one class, for classification) can result in very rough or abnormal-looking score vs. iteration graphs
Some noise in this line chart is expected (i.e., the line will go up and down within a small range). However, if the scores vary quite significantly between iterations, or the overall variation is very large, this can be a problem
The issues mentioned above (learning rate, normalization, data shuffling) may contribute to this.
Setting the minibatch size to a very small number of examples can also contribute to noisy score vs. iteration graphs, and might lead to optimization difficulties
Overview Page and Model Page - Using the Update: Parameter Ratio Chart
The ratio of mean magnitude of updates to parameters is provided on both the overview and model pages
"Mean magnitude" = the average of the absolute value of the parameters or updates at the current time step
The most important use of this ratio is in selecting a learning rate. As a rule of thumb: this ratio should be around 1:1000 = 0.001. On the (log10) chart, this corresponds to a value of -3 (i.e., 10^-3 = 0.001)
Note that this is a rough guide only, and may not be appropriate for all networks. It's often a good starting point, however.
If the ratio diverges significantly from this (for example, above -2 (i.e., 10^-2 = 0.01) or below -4 (i.e., 10^-4 = 0.0001)), your parameters may be too unstable to learn useful features, or may change too slowly to learn useful features
To change this ratio, adjust your learning rate (or sometimes, parameter initialization). In some networks, you may need to set the learning rate differently for different layers.
Keep an eye out for unusually large spikes in the ratio: this may indicate exploding gradients
Model Page: Layer Activations (vs. Time) Chart
This chart can be used to detect vanishing or exploding activations (due to poor weight initialization, too much regularization, lack of data normalization, or too high a learning rate).
This chart should ideally stabilize over time (usually a few hundred iterations)
A good standard deviation for the activations is on the order of 0.5 to 2.0. Values significantly outside of this range may indicate one of the problems mentioned above.
Model Page: Layer Parameters Histogram
The layer parameters histogram is displayed for the most recent iteration only.
For weights, these histograms should have an approximately Gaussian (normal) distribution, after some time
For biases, these histograms will generally start at 0, and will usually end up being approximately Gaussian
One exception to this is for LSTM recurrent neural network layers: the biases for one gate (the forget gate) are set to 1.0 by default (though this is configurable), to help in learning dependencies across long time periods. This results in the bias graphs initially having many biases around 0.0, with another set of biases around 1.0
Keep an eye out for parameters that are diverging to +/- infinity: this may be due to too high a learning rate, or insufficient regularization (try adding some L2 regularization to your network).
Keep an eye out for biases that become very large. This can sometimes occur in the output layer for classification, if the distribution of classes is very imbalanced
Model Page: Layer Updates Histogram
The layer update histogram is displayed for the most recent iteration only.
Note that these are the updates - i.e., the gradients after applying learning rate, momentum, regularization etc
As with the parameter graphs, these should have an approximately Gaussian (normal) distribution
Keep an eye out for very large values: this can indicate exploding gradients in your network
Exploding gradients are problematic as they can 'mess up' the parameters of your network
In this case, it may indicate a weight initialization, learning rate or input/labels data normalization issue
In the case of recurrent neural networks, adding some gradient normalization or gradient clipping may help
Model Page: Parameter Learning Rates Chart
This chart simply shows the learning rates of the parameters of the selected layer, over time.
If you are not using learning rate schedules, the chart will be flat. If you are using learning rate schedules, you can use this chart to track the current value of the learning rate (for each parameter), over time.
The recommended solution (for Maven) is to use the Maven Shade plugin to produce an uber-jar; see the DL4J examples repository's pom.xml for the plugin configuration used there.
Then, create your uber-jar with mvn package and run it via cd target && java -cp dl4j-examples-0.9.1-bin.jar org.deeplearning4j.examples.userInterface.UIExample. Note the "-bin" suffix for the generated JAR file: this includes all dependencies.
Note also that this Maven Shade approach is configured for DL4J's examples repository.
Tools and classes for evaluating neural network performance
When training or deploying a neural network, it is useful to know the accuracy of your model. In DL4J, the Evaluation class and variants of the Evaluation class are available to evaluate your model's performance.
The Evaluation class is used to evaluate the performance for binary and multi-class classifiers (including time series classifiers). This section covers basic usage of the Evaluation Class.
Given a dataset in the form of a DataSetIterator, the easiest way to perform evaluation is to use the built-in evaluate methods on MultiLayerNetwork and ComputationGraph:
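For example (a sketch; model and testIterator are assumed to be your trained network and test-set iterator):

```java
Evaluation eval = model.evaluate(testIterator);
log.info(eval.stats());
```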
However, evaluation can also be performed on individual minibatches. Here is an example taken from our dataexamples/CSVExample in the project.
The CSV example has CSV data for 3 classes of flowers and builds a simple feed forward neural network to classify the flowers based on 4 measurements.
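The evaluation portion of that example looks like this (a sketch; model and testData are created earlier in the example):

```java
Evaluation eval = new Evaluation(3);                     // 1: an Evaluation object for 3 classes
INDArray output = model.output(testData.getFeatures()); // 2: the model's predictions for the test set
eval.eval(testData.getLabels(), output);                 // 3: compare the true labels with the predictions
log.info(eval.stats());                                  // 4: log the evaluation data
```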
The first line creates an Evaluation object with 3 classes. The second line gets the labels from the model for our test dataset. The third line uses the eval method to compare the labels array from the testdata with the labels generated from the model. The fourth line logs the evaluation data to the console.
By default, the .stats() method displays the confusion matrix entries (one per line), accuracy, precision, recall and F1 score. Additionally, the Evaluation class can also calculate and return the following values:
Confusion Matrix
False Positive/Negative Rate
True Positive/Negative
Class Counts
Display the confusion matrix:
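A sketch, assuming the Evaluation object from above:

```java
System.out.println(eval.confusionToString());
```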
Additionally, the confusion matrix can be accessed directly and converted to CSV or HTML:
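A sketch (accessor and converter names per the Evaluation API):

```java
ConfusionMatrix<Integer> confusion = eval.getConfusionMatrix();
String csv  = confusion.toCSV();   // CSV representation
String html = confusion.toHTML();  // HTML representation
```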
To evaluate a network performing regression, use the RegressionEvaluation class.
As with the Evaluation class, RegressionEvaluation on a DataSetIterator can be performed as follows:
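For example (a sketch):

```java
RegressionEvaluation eval = model.evaluateRegression(testIterator);
```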
Here is a code snippet for the single-column case; here, the neural network was predicting the age of shellfish based on measurements.
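A sketch of such a loop (model and testIterator are placeholders):

```java
RegressionEvaluation eval = new RegressionEvaluation(1);  // a single output column
while (testIterator.hasNext()) {
    DataSet ds = testIterator.next();
    INDArray predicted = model.output(ds.getFeatures(), false);
    eval.eval(ds.getLabels(), predicted);
}
```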
Print the statistics for the evaluation:
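```java
System.out.println(eval.stats());
```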
The columns are Mean Squared Error, Mean Absolute Error, Root Mean Squared Error, Relative Squared Error, and the R^2 coefficient of determination.
When performing multiple types of evaluations (for example, Evaluation and ROC on the same network and dataset) it is more efficient to do this in one pass of the dataset, as follows:
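For example, using doEvaluation (a sketch):

```java
Evaluation evaluation = new Evaluation();
ROC roc = new ROC();
model.doEvaluation(testIterator, evaluation, roc);  // a single pass over the dataset for both evaluations
```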
For most users, it is sufficient to simply use MultiLayerNetwork.evaluate(DataSetIterator) or MultiLayerNetwork.evaluateRegression(DataSetIterator) and similar methods. These methods will properly handle masking, if mask arrays are present.
The EvaluationBinary is used for evaluating networks with binary classification outputs - these networks usually have Sigmoid activation functions and XENT loss functions. The typical classification metrics, such as accuracy, precision, recall, F1 score, etc. are calculated for each output.
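A minimal sketch (here labels and networkOutput are INDArrays of shape [minibatch, numOutputs]):

```java
EvaluationBinary eval = new EvaluationBinary();
eval.eval(labels, networkOutput);   // accumulate statistics for each output independently
System.out.println(eval.stats());
```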
ROC (Receiver Operating Characteristic) is another commonly used evaluation metric for the evaluation of classifiers. Three ROC variants exist in DL4J:
ROC - for single binary label (as a single column probability, or 2 column 'softmax' probability distribution).
ROCBinary - for multiple binary labels
ROCMultiClass - for evaluation of non-binary classifiers, using a "one vs. all" approach
These classes have the ability to calculate the area under the ROC curve (AUROC) and area under the Precision-Recall curve (AUPRC), via the calculateAUC() and calculateAUPRC() methods. Furthermore, the ROC and Precision-Recall curves can be obtained using getRocCurve() and getPrecisionRecallCurve().
The ROC and Precision-Recall curves can be exported to HTML for viewing using EvaluationTools.exportRocChartsToHtmlFile(ROC, File), which will export an HTML file with both ROC and P-R curves, that can be viewed in a browser.
Note that all three support two modes of operation/calculation:
Thresholded (approximate AUROC/AUPRC calculation, no memory issues)
Exact (exact AUROC/AUPRC calculation, but can require large amount of memory with very large datasets - i.e., datasets with many millions of examples)
The number of bins can be set using the constructors. Exact mode can be set using the default constructor new ROC(), or explicitly using new ROC(0).
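A sketch of exact-mode usage:

```java
ROC roc = new ROC(0);                 // 0 bins = exact mode
roc.eval(labels, networkOutput);      // accumulate predictions
double auroc = roc.calculateAUC();    // area under the ROC curve
double auprc = roc.calculateAUPRC();  // area under the Precision-Recall curve
```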
Deeplearning4j also has the EvaluationCalibration class, which is designed to analyze the calibration of a classifier. It provides a number of tools for this purpose:
Counts of the number of labels and predictions for each class
Reliability diagram (or reliability curve)
Residual plot (histogram)
Histograms of probabilities, including probabilities for each class separately
Evaluation of a classifier using EvaluationCalibration is performed in a similar manner to the other evaluation classes. The various plots/histograms can be exported to HTML for viewing using EvaluationTools.exportevaluationCalibrationToHtmlFile(EvaluationCalibration, File).
SparkDl4jMultiLayer and SparkComputationGraph both have similar methods for evaluation:
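For example (a sketch; testData is assumed to be a JavaRDD<DataSet>):

```java
Evaluation eval = sparkNet.evaluate(testData);
log.info(eval.stats());
```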
F-beta, G-measure, Matthews Correlation Coefficient and more are also available; see the Evaluation class documentation for the full list.
Time series evaluation is very similar to the above evaluation approaches. Evaluation in DL4J is performed on all (non-masked) time steps separately - for example, a time series of length 10 will contribute 10 predictions/labels to an Evaluation object. One difference with time series is the (optional) presence of mask arrays, which are used to mark some time steps as missing or not present. See the masking documentation for more details.
Evaluation Classes useful for Multi-Task Networks
A multi-task network is a network that is trained to produce multiple outputs. For example, a network given audio samples can be trained to both predict the language spoken and the gender of the speaker. Multi-task configuration is briefly described in the ComputationGraph documentation. The evaluation classes described above can also be applied to multi-task networks.
Understanding common errors like NaNs and tuning hyperparameters.
Neural networks can be difficult to tune. If the network hyperparameters are poorly chosen, the network may learn slowly, or perhaps not at all. This page aims to provide some baseline steps you should take when tuning your network.
Many of these tips have already been discussed in the academic literature. Our purpose is to consolidate them in one site and express them as clearly as possible.
What's the distribution of your data? Are you scaling it properly? As a general rule:
For continuous values: you want these to be in the range of -1 to 1, 0 to 1, or distributed normally with mean 0 and standard deviation 1. This does not have to be exact, but ensuring your inputs are approximately in this range can help during training. Scale down large inputs, and scale up small inputs.
For discrete classes (and, for classification problems, for the output), generally use a one-hot representation. That is, if you have 3 classes, then your data will be represented as [1,0,0], [0,1,0] or [0,0,1] for each of the 3 classes respectively.
Note that it's very important to use the exact same normalization method for both the training data and testing data.
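For example, using DL4J's built-in normalizers (a sketch; NormalizerStandardize normalizes to zero mean and unit variance):

```java
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainIterator);              // collect statistics from the training data only
trainIterator.setPreProcessor(normalizer);  // apply to the training data
testIterator.setPreProcessor(normalizer);   // reuse the SAME statistics for the test data
```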
Deeplearning4j supports several different kinds of weight initializations with the weightInit parameter. These are set using the .weightInit(WeightInit) method in your configuration.
You need to make sure your weights are neither too big nor too small. Xavier weight initialization is usually a good choice for this. For networks with rectified linear (relu) or leaky relu activations, RELU weight initialization is a sensible choice.
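For example (a sketch; the layer sizes are illustrative only):

```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .weightInit(WeightInit.XAVIER)            // default for all layers, unless overridden
    .list()
    .layer(new DenseLayer.Builder().nIn(784).nOut(100)
            .activation(Activation.RELU)
            .weightInit(WeightInit.RELU)      // per-layer override, sensible for relu activations
            .build())
    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
            .nIn(100).nOut(10)
            .activation(Activation.SOFTMAX)
            .build())
    .build();
```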
An epoch is defined as a full pass of the data set.
Too few epochs don't give your network enough time to learn good parameters; too many and you might overfit the training data. One way to choose the number of epochs is to use early stopping. Early stopping can also help to prevent the neural network from overfitting (i.e., can help the net generalize better to unseen data).
The learning rate is one of, if not the, most important hyperparameter. If this is too large or too small, your network may learn very poorly, very slowly, or not at all. Typical values for the learning rate are in the range of 0.1 to 1e-6, though the optimal learning rate is usually data (and network architecture) specific. Some simple advice is to start by trying three different learning rates (1e-1, 1e-3, and 1e-6) to get a rough idea of what it should be, before further tuning this. Ideally, run models with different learning rates simultaneously to save time.
The usual approach to selecting an appropriate learning rate is to use DL4J's visualization interface to visualize the progress of training. You want to pay attention to both the loss over time, and the ratio of update magnitudes to parameter magnitudes (a ratio of approximately 1:1000 is a good place to start). For more information on tuning the learning rate, see this link.
For training neural networks in a distributed manner, you may need a different (frequently higher) learning rate compared to training the same network on a single machine.
You can optionally define a learning rate policy for your neural network. A policy will change the learning rate over time, achieving better results since the learning rate can "slow down" to find closer local minima for convergence. A common policy used is scheduling. See the LeNet example for a learning rate schedule used in practice.
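For example, a map-based schedule (a sketch; the iteration counts and rates are illustrative, and the MapSchedule class is assumed from recent DL4J versions):

```java
Map<Integer, Double> lrSchedule = new HashMap<>();
lrSchedule.put(0, 1e-2);     // learning rate from iteration 0
lrSchedule.put(1000, 5e-3);  // reduced at iteration 1000
lrSchedule.put(3000, 1e-3);  // reduced again at iteration 3000

NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
    .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, lrSchedule)));
```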
Note that if you're using multiple GPUs, this will affect your scheduling. For example, if you have 2x GPUs, then you will need to divide the iterations in your schedule by 2, since the throughput of your training process will be double, and the learning rate schedule is only applicable to the local GPU.
There are two aspects to be aware of, with regard to the choice of activation function.
First, the activation function of the hidden (non-output) layers. As a general rule, 'relu' or 'leakyrelu' activations are good choices for this. Some other activation functions (tanh, sigmoid, etc) are more prone to vanishing gradient problems, which can make learning much harder in deep neural networks. However, for LSTM layers, the tanh activation function is still commonly used.
Second, regarding the activation function for the output layer: this is usually application specific. For classification problems, you generally want to use the softmax activation function, combined with the negative log likelihood / MCXENT (multi-class cross entropy). The softmax activation function gives you a probability distribution over classes (i.e., outputs sum to 1.0). For regression problems, the "identity" activation function is frequently a good choice, in conjunction with the MSE (mean squared error) loss function.
Loss functions for each neural network layer can either be used in pretraining, to learn better weights, or in classification (on the output layer), for achieving some result.
Your net's purpose will determine the loss function you use. For pretraining, choose reconstruction entropy. For classification, use multiclass cross entropy.
Regularization methods can help to avoid overfitting during training. Overfitting occurs when the network predicts the training set very well, but makes poor predictions on data the network has never seen. One way to think about overfitting is that the network memorizes the training data (instead of learning the general relationships in it).
Common types of regularization include:
l1 and l2 regularization penalize large network weights, and avoid weights becoming too large. Some level of l2 regularization is commonly used in practice. However, note that if the l1 or l2 regularization coefficients are too high, they may over-penalize the network, and stop it from learning. Common values for l2 regularization are 1e-3 to 1e-6.
Dropout is a frequently used regularization method that can be very effective. Dropout is most commonly used with a dropout rate of 0.5.
Dropconnect (conceptually similar to dropout, but used much less frequently)
Restricting the total network size (i.e., limit the number of layers and the size of each layer)
To use l1/l2/dropout regularization, use .l1(x), .l2(y) and .dropOut(z) respectively in your configuration. Note that z in dropOut(z) is the probability of retaining an activation.
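For example (a sketch; the coefficients are illustrative only):

```java
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
    .l2(1e-4)       // L2 regularization coefficient
    .dropOut(0.5);  // in DL4J, 0.5 is the probability of RETAINING an activation
```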
A minibatch refers to the number of examples used at a time, when computing gradients and parameter updates. In practice (for all but the smallest data sets), it is standard to break your data set up into a number of minibatches.
The ideal minibatch size will vary. For example, a minibatch size of 10 is frequently too small for GPUs, but can work on CPUs. A minibatch size of 1 will allow a network to train, but will not reap the benefits of parallelism. 32 may be a sensible starting point to try, with minibatches in the range of 16-128 (sometimes smaller or larger, depending on the application and type of network) being common.
In DL4J, the term 'updater' refers to training mechanisms such as momentum, RMSProp, adagrad, and others. Using one of these methods can result in much faster network training compared to 'vanilla' stochastic gradient descent. You can set the updater using the .updater(Updater) configuration option.
The optimization algorithm is how updates are made, given the gradient. The simplest (and most commonly used) method is stochastic gradient descent (SGD), however DL4J also provides SGD with line search, conjugate gradient and LBFGS optimization algorithms. These latter algorithms are more powerful compared to SGD, but considerably more costly per parameter update due to a line search component, and aren't used as much in practice. Note that you can in principle combine any updater with any optimization algorithm.
A good default choice in most cases is to use the stochastic gradient descent optimization algorithm combined with one of the momentum/rmsprop/adagrad updaters, with momentum frequently being used in practice. Note that for momentum, the updater is called NESTEROVS (a reference to the Nesterov variant of momentum), and the momentum rate can be set via the Nesterovs updater's constructor.
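For example (a sketch):

```java
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .updater(new Nesterovs(0.01, 0.9));  // learning rate 0.01, momentum 0.9
```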
When training a neural network, it can sometimes be helpful to apply gradient normalization, to avoid the gradients being too large (the so-called exploding gradient problem, common in recurrent neural networks) or too small. This can be applied using the .gradientNormalization(GradientNormalization) and .gradientNormalizationThreshold(double) methods. For an example of gradient normalization, see GradientNormalization.java. The test code for that example is here.
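For example, element-wise gradient clipping (a sketch; the threshold value is illustrative):

```java
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
    .gradientNormalizationThreshold(1.0);
```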
When training recurrent networks with long time series, it is generally advisable to use truncated backpropagation through time. With 'standard' backpropagation through time (the default in DL4J) the cost per parameter update can become prohibitive. For more details, see this page.
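For example (a sketch; the segment length of 100 time steps is illustrative):

```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .list()
    // ... add your recurrent layers here ...
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(100)   // forward pass segment length
    .tBPTTBackwardLength(100)  // backward pass segment length
    .build();
```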
Q. Why is my neural network throwing NaN values?
A. Backpropagation involves the multiplication of very small gradients; due to limited precision when representing real numbers, values very close to zero cannot be represented. The term for this issue is arithmetic underflow. If your neural network is throwing NaNs, the solution is to retune your network to avoid the very small gradients. This is more likely an issue with deeper neural networks.
You can try using the double data type, but it's usually recommended to retune the net first.
Following the basic tuning tips and monitoring the results is the way to ensure NaN doesn't show up anymore.
The DL4J transfer learning API enables users to:
Modify the architecture of an existing model
Fine tune learning configurations of an existing model.
Hold parameters of a specified layer constant during training, also referred to as "frozen"
Holding certain layers frozen on a network and training is effectively the same as training on a transformed version of the input, the transformed version being the intermediate outputs at the boundary of the frozen layers. This is the process of “feature extraction” from the input data and will be referred to as “featurizing” in this document.
The forward pass to "featurize" the input data on large, pretrained networks can be time consuming. DL4J also provides a TransferLearningHelper class with the following capabilities:
Featurize an input dataset to save for future use
Fit the model with frozen layers with a featurized dataset
Output from the model with frozen layers given a featurized input.
When running multiple epochs users will save on computation time since the expensive forward pass on the frozen layers/vertices will only have to be conducted once.
This example will use VGG16 to classify images belonging to five categories of flowers. The dataset will automatically download from http://download.tensorflow.org/example_images/flower_photos.tgz
Deeplearning4j has a new native model zoo. Read about the deeplearning4j-zoo module for more information on using pretrained models. Here, we load a pretrained VGG-16 model initialized with weights trained on ImageNet:
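A sketch of loading the model (API names per recent versions of the deeplearning4j-zoo module):

```java
ZooModel zooModel = VGG16.builder().build();
ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
log.info(vgg16.summary());
```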
The final layer of VGG16 does a softmax regression on the 1000 classes in ImageNet. We modify the very last layer to give predictions for five classes keeping the other layers frozen.
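A sketch, following the DL4J flowers example (the layer names "fc2" and "predictions" come from the VGG16 zoo model; the fine-tune values are illustrative):

```java
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
    .updater(new Nesterovs(5e-5))
    .seed(12345)
    .build();

ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
    .fineTuneConfiguration(fineTuneConf)
    .setFeatureExtractor("fc2")                  // freeze "fc2" and everything before it
    .removeVertexKeepConnections("predictions")  // remove the 1000-class ImageNet output layer
    .addLayer("predictions",
        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nIn(4096).nOut(5)                   // 5 flower classes
            .weightInit(WeightInit.XAVIER)
            .activation(Activation.SOFTMAX)
            .build(),
        "fc2")
    .build();
```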
After a mere thirty iterations, which in this case is exposure to 450 images, the model attains an accuracy > 75% on the test dataset. This is rather remarkable considering the complexity of training an image classifier from scratch.
Here we hold all but the last three dense layers frozen and attach new dense layers onto it. Note that the primary intent here is to demonstrate the use of the API, secondary to what might give better results.
Say we have saved off our model from (B) and now want to allow “block_5” layers to train.
We use the transfer learning helper API. Note this freezes the layers of the model passed in.
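A sketch:

```java
// Freezes all layers up to (and including) the vertex "fc2", in place
TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer, "fc2");
```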
Here is how you obtain the featurized version of the dataset at the specified layer "fc2".
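A sketch (trainIterator is a placeholder; the saved datasets can be reused across epochs):

```java
while (trainIterator.hasNext()) {
    DataSet featurized = transferLearningHelper.featurize(trainIterator.next());
    // save the featurized DataSet for later reuse, e.g. featurized.save(someFile)
}
```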
Here is how you can fit with a featurized dataset. vgg16Transfer is a model set up in (A) of section III.
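A sketch:

```java
transferLearningHelper.fitFeaturized(featurizedDataSet);  // trains only the unfrozen layers
```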
The TransferLearning builder returns a new instance of a DL4J model.
Keep in mind this is a second model that leaves the original one untouched. For large pretrained networks, take memory requirements into consideration and adjust your JVM heap space accordingly.
The trained model helper imports models from Keras without enforcing a training configuration.
Therefore the last layer (as seen when printing the summary) is a dense layer and not an output layer with a loss function. Therefore, to modify the nOut of an output layer, we delete the layer vertex, keeping its connections, and add back in a new output layer with the same name, a different nOut, a suitable loss function, etc.
Changing nOuts at a layer/vertex will modify nIn of the layers/vertices it fans into.
When changing nOut users can specify a weight initialization scheme or a distribution for the layer as well as a separate weight initialization scheme or distribution for the layers it fans out to.
Frozen layer configurations are not saved when writing the model to disk.
In other words, a model with frozen layers when serialized and read back in will not have any frozen layers. To continue training holding specific layers constant the user is expected to go through the transfer learning helper or the transfer learning API. There are two ways to “freeze” layers in a dl4j model.
On a copy: With the transfer learning API which will return a new model with the relevant frozen layers
In place: With the transfer learning helper API which will apply the frozen layers to the given model.
FineTune configurations will selectively update learning parameters.
For example, if a learning rate is specified, this learning rate will apply to all unfrozen/trainable layers in the model. However, newly added layers can override this learning rate by specifying their own learning rates in the layer builder.
Terminate a training session given certain conditions.
When training neural networks, numerous decisions need to be made regarding the settings (hyperparameters) used, in order to obtain good performance. One such hyperparameter is the number of training epochs: that is, how many full passes of the data set (epochs) should be used? If we use too few epochs, we might underfit (i.e., not learn everything we can from the training data); if we use too many epochs, we might overfit (i.e., fit the 'noise' in the training data, and not the signal).
Early stopping attempts to remove the need to manually set this value. It can also be considered a type of regularization method (like L1/L2 weight decay and dropout) in that it can stop the network from overfitting.
The idea behind early stopping is relatively simple:
Split data into training and test sets
At the end of each epoch (or, every N epochs):
evaluate the network performance on the test set
if the network outperforms the previous best model: save a copy of the network at the current epoch
Take as our final model the model that has the best test set performance
This is shown graphically below:
The best model is the one saved at the time of the vertical dotted line - i.e., the model with the best accuracy on the test set.
Using DL4J's early stopping functionality requires you to provide a number of configuration options:
A score calculator, such as the DataSetLossCalculator (JavaDoc, Source Code) for a MultiLayerNetwork, or DataSetLossCalculatorCG (JavaDoc, Source Code) for a ComputationGraph. This is used to calculate the score at every epoch (for example: the loss function value on a test set, or the accuracy on the test set)
How frequently we want to calculate the score function (default: every epoch)
One or more termination conditions, which tell the training process when to stop. There are two classes of termination conditions:
Epoch termination conditions: evaluated every N epochs
Iteration termination conditions: evaluated once per minibatch
A model saver, that defines how models are saved
An example, with an epoch termination condition of maximum of 30 epochs, a maximum of 20 minutes training time, calculating the score every epoch, and saving the intermediate results to disk:
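A sketch of such a configuration (testIterator, trainIterator, directory and networkConfiguration are placeholders):

```java
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;

EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
    .epochTerminationConditions(new MaxEpochsTerminationCondition(30))
    .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
    .scoreCalculator(new DataSetLossCalculator(testIterator, true))
    .evaluateEveryNEpochs(1)
    .modelSaver(new LocalFileModelSaver(directory))
    .build();

EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, networkConfiguration, trainIterator);

// Conduct early stopping training:
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

// Print out the results:
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());

// Get the best model:
MultiLayerNetwork bestModel = result.getBestModel();
```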
You can also implement your own iteration and epoch termination conditions.
The early stopping implementation described above will only work with a single device. However, EarlyStoppingParallelTrainer provides similar functionality and allows you to optimize for multiple CPUs or GPUs. EarlyStoppingParallelTrainer wraps your model in a ParallelWrapper class and performs localized distributed training.
Note that EarlyStoppingParallelTrainer doesn't support all of the functionality of its single-device counterpart: it is not UI-compatible and may not work with complex iteration listeners. This is due to how the model is distributed and copied in the background.