Using Multiple GPUs

Training neural network models can be computationally expensive. To speed up training, you can train your models in parallel across multiple GPUs if they are installed on your machine. With Deeplearning4j (DL4J), this is straightforward. In this tutorial we will use the MNIST dataset (a dataset of handwritten digit images) to train a convolutional neural network in parallel on multiple GPUs.

Note: The same approach also helps when a single training process cannot fully load your CPU. In that case, keep the CPU-specific backend instead of switching to the CUDA one.


Prerequisites

  • You must have multiple CUDA-compatible GPUs, ideally of the same speed

  • You must set up your project to use the CUDA backend; for help, see Backends
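Before going further, it can help to confirm how many devices the active backend actually sees. The snippet below is a minimal sketch, assuming ND4J is on the classpath; the names `deviceCount` and the warning text are illustrative only. On a CPU backend the reported count is typically 1.

```scala
import org.nd4j.linalg.factory.Nd4j

// Ask the active ND4J backend how many devices it can use.
val deviceCount = Nd4j.getAffinityManager.getNumberOfDevices
println(s"Backend: ${Nd4j.getBackend.getClass.getSimpleName}, devices: $deviceCount")

// Illustrative warning: with a single device, parallel workers share that device.
if (deviceCount < 2) println("Fewer than two devices visible to the backend.")
```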


import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.parallelism.ParallelWrapper;

Data Set

To obtain the data, we use the built-in DataSetIterator for MNIST, MnistDataSetIterator, with a random seed of 12345. These iterators can feed the data directly into a neural network.

val batchSize = 128
// Arguments: batch size, training split (true) or test split (false), random seed
val mnistTrain = new MnistDataSetIterator(batchSize, true, 12345)
val mnistTest = new MnistDataSetIterator(batchSize, false, 12345)

Model Configuration

Next, we define a convolutional network configuration and initialize the model.

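val nChannels = 1 // MNIST images are grayscale, so there is one input channel
val outputNum = 10 // ten digit classes
val seed = 123
// The original listing was truncated. Calls marked "assumed" restore a runnable LeNet-style configuration; their values are illustrative.
val conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .list() // assumed: needed before adding indexed layers
    .layer(0, new ConvolutionLayer.Builder(5, 5)
        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
        .nIn(nChannels)
        .stride(1, 1)
        .nOut(20).activation(Activation.IDENTITY).build()) // assumed
    .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
        .kernelSize(2, 2).stride(2, 2).build()) // assumed
    .layer(2, new ConvolutionLayer.Builder(5, 5)
        //Note that nIn need not be specified in later layers
        .stride(1, 1)
        .nOut(50).activation(Activation.IDENTITY).build()) // assumed
    .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
        .kernelSize(2, 2).stride(2, 2).build()) // assumed
    .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
        .nOut(500).build()) // assumed
    .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nOut(outputNum).activation(Activation.SOFTMAX).build()) // softmax assumed
    .setInputType(InputType.convolutionalFlat(28, 28, 1)) // MNIST rows arrive flattened (784 values); treat them as 28x28x1 images
    .build()
val model = new MultiLayerNetwork(conf)
model.init() // allocate and initialize the network parameters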
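The comment that nIn need not be specified in later layers relies on the input type: given the 28x28x1 input, the size of each layer's output can be derived from the kernel and stride of the layer before it. The plain-Scala sketch below traces that arithmetic. The 5x5 kernels with stride 1 come from the configuration; the 2x2 pooling with stride 2 and the absence of padding are assumptions made for illustration, and `outSize` is a hypothetical helper, not a DL4J function.

```scala
// Output width/height of a convolution or pooling layer without padding:
// out = (in - kernel) / stride + 1
def outSize(in: Int, kernel: Int, stride: Int): Int = (in - kernel) / stride + 1

val afterConv1 = outSize(28, 5, 1)         // 5x5 convolution, stride 1
val afterPool1 = outSize(afterConv1, 2, 2) // assumed 2x2 max pooling, stride 2
val afterConv2 = outSize(afterPool1, 5, 1) // 5x5 convolution, stride 1
val afterPool2 = outSize(afterConv2, 2, 2) // assumed 2x2 max pooling, stride 2

println(s"28 -> $afterConv1 -> $afterPool1 -> $afterConv2 -> $afterPool2")
// prints "28 -> 24 -> 12 -> 8 -> 4"
```

Under these assumptions, the dense layer receives 4 x 4 feature maps, one per filter of the second convolution.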

Parallel Wrapper

Next we need to configure the parallel training with the ParallelWrapper class using the MultiLayerNetwork as the input. The ParallelWrapper will take care of load balancing between different GPUs.

The idea is that the ParallelWrapper duplicates the model, one copy per worker. Each of the prespecified number of workers (in this case 2) then trains its own copy on its own share of the data. After a specified number of iterations (in this case 3), the parameters of all copies are averaged and every worker receives the same averaged model. Training continues in this way until the model is fully trained.

val wrapper = new ParallelWrapper.Builder(model)
    .workers(2).averagingFrequency(3).build() // 2 workers; average parameters every 3 iterations
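The averaging scheme described above can be illustrated without DL4J. The following plain-Scala toy is only a sketch of the idea, not the ParallelWrapper implementation: each "model" is a small parameter vector, `localStep` stands in for a training iteration on a worker's own data, and the `targets` values are hypothetical.

```scala
// One stand-in "training iteration": move every parameter toward the worker's target.
def localStep(params: Vector[Double], target: Double, lr: Double): Vector[Double] =
  params.map(p => p + lr * (target - p))

// Element-wise mean of all workers' parameter vectors.
def average(copies: Seq[Vector[Double]]): Vector[Double] =
  copies.transpose.map(col => col.sum / col.size).toVector

val workers = 2
val averagingFrequency = 3
val targets = Seq(1.0, 3.0)                      // hypothetical per-worker data
var copies = Seq.fill(workers)(Vector(0.0, 0.0)) // every worker starts from the same model

for (iteration <- 1 to 6) {
  // Each worker updates its own copy independently.
  copies = copies.zip(targets).map { case (p, t) => localStep(p, t, 0.5) }
  if (iteration % averagingFrequency == 0) {
    // Synchronization point: average, then hand the same model back to every worker.
    val mean = average(copies)
    copies = Seq.fill(workers)(mean)
  }
}
println(copies.head)
```

Between synchronization points the copies drift apart; after each averaging step they are identical again, which is why a smaller averaging frequency trades extra communication for closer agreement between workers.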

To train the model, call the fit method of the ParallelWrapper directly on the DataSetIterator. Because the ParallelWrapper class handles all the training details behind the scenes, parallelizing this process with DL4J is very simple.
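As a minimal sketch of that call, assuming the `wrapper` and `mnistTrain` values from the listings above are in scope (the timing code and the `seconds` name are illustrative additions):

```scala
// Continues from the listings above: `wrapper` and `mnistTrain` are defined there.
val start = System.currentTimeMillis()
wrapper.fit(mnistTrain) // one pass over the training iterator, distributed across the workers
val seconds = (System.currentTimeMillis() - start) / 1000.0
println(f"Parallel training pass finished in $seconds%.1f s")
```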