JavaSparkContext sc = ...;
JavaRDD<DataSet> trainingData = ...;
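//For illustration only: one common way to build the JavaRDD<DataSet> is via DataVec,
//e.g. from a CSV file on HDFS. The path, labelIndex and numClasses below are placeholders:
//JavaRDD<String> rddString = sc.textFile("hdfs:///path/to/data.csv");
//RecordReader recordReader = new CSVRecordReader();
//JavaRDD<List<Writable>> rddWritables = rddString.map(new StringToWritablesFunction(recordReader));
//JavaRDD<DataSet> trainingData = rddWritables.map(new DataVecDataSetFunction(labelIndex, numClasses, false));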
//Model setup, as for single-node training. Either a MultiLayerConfiguration or a ComputationGraphConfiguration can be used
MultiLayerConfiguration model = ...;
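//A minimal single-node configuration, for illustration (the hyperparameters here are
//placeholders; any MultiLayerConfiguration will work):
//MultiLayerConfiguration model = new NeuralNetConfiguration.Builder()
//        .updater(new Adam(1e-3))
//        .list()
//        .layer(new DenseLayer.Builder().nIn(784).nOut(128).activation(Activation.RELU).build())
//        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
//                .nIn(128).nOut(10).activation(Activation.SOFTMAX).build())
//        .build();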
// Configure distributed training required for gradient sharing implementation
VoidConfiguration conf = VoidConfiguration.builder()
        .unicastPort(40123)             //Port that workers will use to communicate. Use any free port
        .networkMask("10.0.0.0/16")     //Network mask for communication. Examples: 10.0.0.0/24, 192.168.0.0/16, etc.
        .controllerAddress("10.0.2.4")  //IP address of the master/driver
        .build();
//Create the TrainingMaster instance
TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, batchSizePerWorker)
        .batchSizePerWorker(batchSizePerWorker)  //Batch size for training
        .updatesThreshold(1e-3)                  //Update threshold for quantization/compression. See the technical explanation page
        .workersPerNode(numWorkersPerNode)       //Typically equal to the number of GPUs per node. For CPUs: use 1, or > 1 for large core-count CPUs
        .meshBuildMode(MeshBuildMode.MESH)       //Or MeshBuildMode.PLAIN for fewer than 32 nodes
        .build();
//Create the SparkDl4jMultiLayer instance
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, model, trainingMaster);
for (int i = 0; i < numEpochs; i++) {
    sparkNet.fit(trainingData);
}
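
//After training (sketch; assumes a JavaRDD<DataSet> testData is available for evaluation):
MultiLayerNetwork trainedNetwork = sparkNet.getNetwork();  //Fetch a copy of the trained network
Evaluation evaluation = sparkNet.evaluate(testData);       //Distributed evaluation on the test set
System.out.println(evaluation.stats());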