MultiLayerNetwork And ComputationGraph
DL4J provides the following classes for building networks:
MultiLayerNetwork
ComputationGraph
MultiLayerNetwork consists of a single input layer and a single output layer with a stack of layers in between them.
ComputationGraph is used for constructing networks with a more complex architecture than MultiLayerNetwork. It can have multiple input layers and multiple output layers, and the layers in between can be connected as a directed acyclic graph.
Network Configurations
Whether you create a MultiLayerNetwork or a ComputationGraph, you have to provide it with a network configuration through NeuralNetConfiguration.Builder. As the name implies, it provides a Builder pattern to configure a network. To create a MultiLayerNetwork, we build a MultiLayerConfiguration, and for a ComputationGraph, it’s a ComputationGraphConfiguration.
The pattern goes like this: [High Level Configuration] -> [Configure Layers] -> [Build Configuration]
Required imports
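The configurations discussed on this page draw on classes like the ones below. This import list is a sketch that assumes a DL4J 1.0.0-beta or 1.0.0-M release. Package locations, especially for updaters and activations, have moved between versions, so check them against the version you use.

```java
// Configuration builders
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
// Layer and vertex configurations
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
// Network classes built from the configurations
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
// ND4J: activations, updaters (optimizers) and loss functions
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
```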
Building a MultiLayerConfiguration
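Here is a minimal sketch of a MultiLayerConfiguration that follows the three-step pattern. It assumes a classifier with 784 inputs, one hidden layer of 100 units and 10 output classes. The sizes, the class name MlpConfigExample and the Nesterovs momentum updater are illustrative assumptions, not requirements.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class MlpConfigExample {
    // Two-layer feed-forward configuration: 784 inputs -> 100 hidden units -> 10 classes.
    static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                // [High Level Configuration]: settings shared by every layer
                .seed(123)
                .updater(new Nesterovs(0.1, 0.9))   // learning rate 0.1, momentum 0.9
                .weightInit(WeightInit.XAVIER)
                // [Configure Layers]: list() switches to the ListBuilder
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(784).nOut(100)
                        .activation(Activation.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(100).nOut(10)
                        .activation(Activation.SOFTMAX)
                        .build())
                // [Build Configuration]
                .build();
    }

    public static void main(String[] args) {
        MultiLayerConfiguration conf = build();
        System.out.println("Number of layers: " + conf.getConfs().size());
    }
}
```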
What did we do here?
High Level Configuration
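High-level settings apply to every layer unless a layer overrides them. The sketch below sets a seed, an Adam updater, a default weight initialization and a default activation, and then overrides the activation on the output layer. The layer sizes and the class name HighLevelConfigExample are hypothetical.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class HighLevelConfigExample {
    static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .seed(42)                        // fixes random initialization for reproducibility
                .updater(new Adam(1e-3))         // optimizer and learning rate used by all layers
                .weightInit(WeightInit.XAVIER)   // default weight initialization
                .activation(Activation.TANH)     // default activation for layers that set none
                .list()
                // Inherits TANH and XAVIER from the high-level configuration
                .layer(new DenseLayer.Builder().nIn(4).nOut(8).build())
                // Overrides the default activation with SOFTMAX
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(8).nOut(3)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();
    }

    public static void main(String[] args) {
        System.out.println(build().toJson());
    }
}
```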
Configuration of Layers
Here we are calling list() to get the ListBuilder. It provides the API needed to add layers to the network through the layer(arg1, arg2) function.
The first parameter is the index of the position where the layer needs to be added.
The second parameter is the layer configuration to add, such as a DenseLayer or an OutputLayer.
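The sketch below shows how the index works, using hypothetical layer sizes. The two configurations describe the same three-layer network. One passes explicit indices to layer(index, layer), and the other uses the layer(layer) overload, which in recent DL4J versions appends layers in call order. Comparing their JSON is one quick way to confirm they match. The helper methods dense and output are hypothetical conveniences.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class LayerIndexExample {
    // Fresh layer objects per call, so the two configurations share no state.
    static DenseLayer dense(int nIn, int nOut) {
        return new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU).build();
    }

    static OutputLayer output(int nIn, int nOut) {
        return new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build();
    }

    // Explicit positions: index 0 is the layer that receives the network input.
    static MultiLayerConfiguration withIndices() {
        return new NeuralNetConfiguration.Builder().seed(7).list()
                .layer(0, dense(10, 16))
                .layer(1, dense(16, 16))
                .layer(2, output(16, 2))
                .build();
    }

    // Same layers, appended in call order by the layer(Layer) overload.
    static MultiLayerConfiguration withoutIndices() {
        return new NeuralNetConfiguration.Builder().seed(7).list()
                .layer(dense(10, 16))
                .layer(dense(16, 16))
                .layer(output(16, 2))
                .build();
    }

    public static void main(String[] args) {
        System.out.println("Same configuration: "
                + withIndices().toJson().equals(withoutIndices().toJson()));
    }
}
```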
Each layer is itself built with a builder, following the same pattern as the network configuration.
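Here is a minimal sketch of one such layer builder. It assumes a fully connected hidden layer with 784 inputs and 100 units (hypothetical sizes). The Builder collects settings through chained calls, and build() returns the DenseLayer configuration object. The class name LayerBuilderExample is hypothetical.

```java
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;

class LayerBuilderExample {
    static DenseLayer hiddenLayer() {
        return new DenseLayer.Builder()
                .nIn(784)                      // size of the incoming activations
                .nOut(100)                     // number of units in this layer
                .weightInit(WeightInit.XAVIER) // layer-level setting
                .activation(Activation.RELU)
                .build();                      // produces the DenseLayer configuration object
    }

    public static void main(String[] args) {
        DenseLayer layer = hiddenLayer();
        System.out.println("nIn=" + layer.getNIn() + ", nOut=" + layer.getNOut());
    }
}
```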
Building the Configuration
Finally, the last build() call builds the configuration for us.
Sanity checking for our MultiLayerConfiguration
You can get your network configuration as a String, JSON, or YAML for sanity checking. For JSON, we can use the toJson() function, and for YAML, toYaml().
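The sketch below prints the JSON form of a small configuration and checks that it survives a round trip through MultiLayerConfiguration.fromJson. This is a quick way to confirm that nothing was lost. The tiny network and the class name ConfigSanityCheck are illustrative assumptions. For YAML, toYaml() and MultiLayerConfiguration.fromYaml follow the same pattern.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class ConfigSanityCheck {
    static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Nesterovs(0.1, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(8).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(8).nOut(3).activation(Activation.SOFTMAX).build())
                .build();
    }

    // Serializes to JSON, parses it back and re-serializes; identical text means nothing was lost.
    static boolean roundTripMatches() {
        String json = build().toJson();
        MultiLayerConfiguration restored = MultiLayerConfiguration.fromJson(json);
        return restored.toJson().equals(json);
    }

    public static void main(String[] args) {
        System.out.println(build().toJson());   // human-readable view of every layer's settings
        System.out.println("Round trip OK: " + roundTripMatches());
    }
}
```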
Creating a MultiLayerNetwork
Finally, to create a MultiLayerNetwork, we pass the configuration to its constructor and call init() to initialize the network's parameters.
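Here is a sketch of creating a network from a configuration. It assumes a small network with 4 inputs, 8 hidden units and 3 outputs (hypothetical sizes and class name). The init() call allocates and initializes the weights and biases, and the network cannot be used before it runs. With these sizes the parameter count is (4×8 + 8) + (8×3 + 3) = 67.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class CreateNetworkExample {
    static MultiLayerNetwork create() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(8).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(8).nOut(3).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();   // allocates and initializes the weights and biases
        return model;
    }

    public static void main(String[] args) {
        MultiLayerNetwork model = create();
        System.out.println(model.summary());           // per-layer table of shapes and parameter counts
        INDArray out = model.output(Nd4j.ones(2, 4));  // forward pass: 2 examples, 4 features each
        System.out.println(out);                       // 2 rows of 3 softmax probabilities
    }
}
```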
Building a ComputationGraphConfiguration
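Here is a minimal sketch of a ComputationGraphConfiguration for a graph with one input and one output. It assumes a 3-feature input, one hidden layer named "L1" and an output layer named "out". The names, sizes and class name are hypothetical, and any unique strings work as vertex names.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class GraphConfigExample {
    static ComputationGraphConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Nesterovs(0.1, 0.9))   // [High Level Configuration], as for MultiLayerNetwork
                .graphBuilder()                     // GraphBuilder instead of the ListBuilder from list()
                .addInputs("input")                 // names the graph's single input
                // addLayer(name, layer, inputs...): "L1" reads from "input"
                .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4)
                        .activation(Activation.RELU).build(), "input")
                // The output layer reads from "L1" and declares its loss with lossFunction(...)
                .addLayer("out", new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(4).nOut(3).build(), "L1")
                .setOutputs("out")                  // marks "out" as the network output
                .build();                           // [Build Configuration]
    }

    public static void main(String[] args) {
        ComputationGraphConfiguration conf = build();
        System.out.println("Inputs: " + conf.getNetworkInputs() + ", outputs: " + conf.getNetworkOutputs());
    }
}
```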
What did we do here?
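The only difference here is the way we build layers. Instead of calling the list() function, we call graphBuilder() to get a GraphBuilder for building our ComputationGraphConfiguration. The GraphBuilder functions work as follows. addInputs names the network's input(s). addLayer takes a layer name, the layer configuration, and the names of the inputs or layers that feed into it. setOutputs names the layer(s) that produce the network's output(s).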
Output layers also take a lossFunction setting, which defines the loss function that the layer is trained with.
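The choice of loss usually goes hand in hand with the output activation. The sketch below uses hypothetical sizes and a hypothetical class name. It pairs MSE with an identity activation for a regression head, and negative log-likelihood with softmax for a classification head. The second head passes the loss to the OutputLayer.Builder constructor, which is equivalent to calling lossFunction(...).

```java
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class LossFunctionExample {
    // Regression head: unbounded outputs, mean squared error.
    static OutputLayer regressionHead() {
        return new OutputLayer.Builder()
                .lossFunction(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY)
                .nIn(4).nOut(2)
                .build();
    }

    // Classification head: softmax probabilities, negative log-likelihood, loss given to the constructor.
    static OutputLayer classificationHead() {
        return new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)
                .nIn(4).nOut(3)
                .build();
    }

    public static void main(String[] args) {
        System.out.println(regressionHead().getLossFn());
        System.out.println(classificationHead().getLossFn());
    }
}
```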
Sanity checking for our ComputationGraphConfiguration
You can get your network configuration as a String, JSON, or YAML for sanity checking. For JSON, we can use the toJson() function, and for YAML, toYaml().
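The sketch below uses a tiny hypothetical graph and class name. It prints the YAML form of a ComputationGraphConfiguration, then checks that a JSON round trip through ComputationGraphConfiguration.fromJson reproduces the same JSON text.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class GraphSanityCheck {
    static ComputationGraphConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("input")
                .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L1")
                .setOutputs("out")
                .build();
    }

    // Parse the JSON back and re-serialize; identical text means the graph survived intact.
    static boolean roundTripMatches() {
        String json = build().toJson();
        return ComputationGraphConfiguration.fromJson(json).toJson().equals(json);
    }

    public static void main(String[] args) {
        System.out.println(build().toYaml());   // YAML view of vertices, inputs and layer settings
        System.out.println("Round trip OK: " + roundTripMatches());
    }
}
```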
Creating a ComputationGraph
Finally, to create a ComputationGraph, we pass the configuration to its constructor and call init() to initialize the graph's parameters.
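Here is a sketch of creating and initializing a graph. It assumes a hypothetical graph with 3 inputs, 4 hidden units and 3 outputs, so the parameter count is (3×4 + 4) + (4×3 + 3) = 31. As with MultiLayerNetwork, init() must be called before the graph is used. outputSingle(...) is a convenience method for graphs with exactly one output.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class CreateGraphExample {
    static ComputationGraph create() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .graphBuilder()
                .addInputs("input")
                .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4)
                        .activation(Activation.RELU).build(), "input")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L1")
                .setOutputs("out")
                .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();   // allocates and initializes all parameters in the graph
        return graph;
    }

    public static void main(String[] args) {
        ComputationGraph graph = create();
        INDArray out = graph.outputSingle(Nd4j.ones(2, 3));   // 2 examples, 3 features each
        System.out.println(out);
    }
}
```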
More MultiLayerConfiguration Examples
Regularization
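The sketch below shows L2 regularization with hypothetical sizes, coefficients and class name. Calling l2(...) on NeuralNetConfiguration.Builder applies the penalty to the weights of every layer (biases have a separate l2Bias setting). Calling the same method on a layer builder overrides it for that layer only.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class RegularizationExample {
    static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))
                .l2(1e-4)                       // default L2 penalty on every layer's weights
                .list()
                .layer(new DenseLayer.Builder().nIn(20).nOut(32)
                        .activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(32).nOut(4)
                        .activation(Activation.SOFTMAX)
                        .l2(1e-3)               // stronger penalty for this layer only
                        .build())
                .build();
    }

    public static void main(String[] args) {
        System.out.println(build().toJson());
    }
}
```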
Dropout connects
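The sketch below shows dropout, with hypothetical layer sizes and class name. In DL4J, the value passed to dropOut(...) is the probability of retaining an activation, not of dropping it, so dropOut(0.5) keeps about half. Dropout is applied to the layer's inputs during training only.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class DropoutExample {
    static MultiLayerConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))
                .list()
                .layer(new DenseLayer.Builder().nIn(20).nOut(64)
                        .activation(Activation.RELU).build())   // no dropout on the raw inputs
                // Retains each input activation of this layer with probability 0.5 during training
                .layer(new DenseLayer.Builder().nIn(64).nOut(64)
                        .activation(Activation.RELU)
                        .dropOut(0.5)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(64).nOut(4).activation(Activation.SOFTMAX).build())
                .build();
    }

    public static void main(String[] args) {
        System.out.println(build().getConf(1).getLayer().getIDropout());
    }
}
```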
Bias initialization
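The sketch below shows bias initialization, with hypothetical sizes and class name. Calling biasInit(...) on NeuralNetConfiguration.Builder sets the starting value of every layer's biases (DL4J's default is 0), and a layer builder can override it. After init(), the check reads back the first bias value of each layer through the "0_b" / "1_b" parameter keys that MultiLayerNetwork uses.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class BiasInitExample {
    static MultiLayerNetwork build() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .biasInit(0.1)                  // default starting value for all biases
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(8)
                        .activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(8).nOut(3)
                        .biasInit(0.0)          // layer-level override
                        .activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();                             // biases are filled with their initial values here
        return net;
    }

    // First bias value of the given layer; parameter keys look like "0_b", "1_b", ...
    static double firstBias(MultiLayerNetwork net, int layer) {
        return net.getParam(layer + "_b").getDouble(0);
    }

    public static void main(String[] args) {
        MultiLayerNetwork net = build();
        System.out.println(firstBias(net, 0) + " " + firstBias(net, 1));
    }
}
```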
More ComputationGraphConfiguration Examples
Recurrent Network with Skip Connections
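The sketch below builds a recurrent graph with a skip connection. It assumes 5 input features and 5 LSTM units (hypothetical sizes and class name). The output layer reads both the raw input and the LSTM's output. When a layer lists several inputs, the ComputationGraph merges them along the feature dimension, so the output layer's nIn is 5 + 5. The parameter count is 220 for the LSTM (input weights 5×20, recurrent weights 5×20, bias 20) plus 55 for the output layer.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class RnnSkipExample {
    static ComputationGraph build() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("input")                  // sequences with 5 features per time step
                .addLayer("L1", new LSTM.Builder().nIn(5).nOut(5)
                        .activation(Activation.TANH).build(), "input")
                // Skip connection: "L2" reads the raw input and the LSTM output.
                // The two 5-feature inputs are merged, so nIn = 5 + 5.
                .addLayer("L2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(5 + 5).nOut(5).build(), "input", "L1")
                .setOutputs("L2")
                .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    public static void main(String[] args) {
        System.out.println("Parameters: " + build().numParams());
    }
}
```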
Multiple Inputs and Merge Vertex
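The sketch below builds a graph with two inputs. It assumes 3 features per input and 4 units per branch (hypothetical sizes, vertex names and class name). Each input gets its own dense layer, and an explicit MergeVertex concatenates their activations, so the output layer's nIn is 4 + 4.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class MergeVertexExample {
    static ComputationGraphConfiguration build() {
        return new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("input1", "input2")                      // two separate inputs
                .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
                .addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
                .addVertex("merge", new MergeVertex(), "L1", "L2")  // concatenates 4 + 4 features
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(4 + 4).nOut(3).build(), "merge")
                .setOutputs("out")
                .build();
    }

    public static void main(String[] args) {
        ComputationGraph graph = new ComputationGraph(build());
        graph.init();
        // One array per input, in the order given to addInputs
        INDArray out = graph.outputSingle(Nd4j.ones(2, 3), Nd4j.ones(2, 3));
        System.out.println(out);
    }
}
```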
Multi-Task Learning
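The sketch below shows multi-task learning, with hypothetical sizes, names and class name. A shared hidden layer feeds two heads. One is a 3-class classification head trained with negative log-likelihood, and the other is a 2-value regression head trained with MSE. Because the graph has two outputs, output(...) returns one array per head, in the order given to setOutputs.

```java
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class MultiTaskExample {
    static ComputationGraph build() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(0.01))
                .graphBuilder()
                .addInputs("input")
                // Shared representation used by both tasks
                .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4)
                        .activation(Activation.RELU).build(), "input")
                // Task 1: 3-class classification
                .addLayer("out1", new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(4).nOut(3).build(), "L1")
                // Task 2: regression of 2 values
                .addLayer("out2", new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(4).nOut(2).build(), "L1")
                .setOutputs("out1", "out2")
                .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    // One forward pass on 2 dummy examples; returns one array per output, in setOutputs order.
    static INDArray[] outputs() {
        return build().output(Nd4j.ones(2, 3));
    }

    public static void main(String[] args) {
        INDArray[] outs = outputs();
        System.out.println("Classification head:\n" + outs[0]);
        System.out.println("Regression head:\n" + outs[1]);
    }
}
```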