Saving and Loading Models
Saving and loading of neural networks.
MultiLayerNetwork and ComputationGraph both have save and load methods.
You can save/load a MultiLayerNetwork using:
MultiLayerNetwork net = ...
net.save(new File("...");
MultiLayerNetwork net2 = MultiLayerNetwork.load(new File("..."), true);
Similarly, you can save/load a ComputationGraph using:
ComputationGraph net = ...
net.save(new File("..."));
ComputationGraph net2 = ComputationGraph.load(new File("..."), true);
Internally, these methods use the
ModelSerializer
class, which handles loading and saving models. There are two methods for saving models shown in the examples through the link. The first example saves a normal multi layer network, the second one saves a computation graph.Here is a basic example with code to save a computation graph using the
ModelSerializer
class, as well as an example of using ModelSerializer to save a neural net built using MultiLayer configuration.If your model uses probabilities (i.e. DropOut/DropConnect), it may make sense to save it separately, and apply it after model is restored; i.e:
Nd4j.getRandom().setSeed(12345);
ModelSerializer.restoreMultiLayerNetwork(modelFile);
This will guarantee equal results between sessions/JVMs.
Utility class suited to save/restore neural net models
writeModel
public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater) throws IOException
Write a model to a file
- param model the model to write
- param file the file to write to
- param saveUpdater whether to save the updater or not
- throws IOException
writeModel
public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater,DataNormalization dataNormalization) throws IOException
Write a model to a file
- param model the model to write
- param file the file to write to
- param saveUpdater whether to save the updater or not
- param dataNormalization the normalizer to save (optional)
- throws IOException
writeModel
public static void writeModel(@NonNull Model model, @NonNull String path, boolean saveUpdater) throws IOException
Write a model to a file path
- param model the model to write
- param path the path to write to
- param saveUpdater whether to save the updater or not
- throws IOException
writeModel
public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater)
throws IOException
Write a model to an output stream
- param model the model to save
- param stream the output stream to write to
- param saveUpdater whether to save the updater for the model or not
- throws IOException
writeModel
public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater,DataNormalization dataNormalization)
throws IOException
Write a model to an output stream
- param model the model to save
- param stream the output stream to write to
- param saveUpdater whether to save the updater for the model or not
- param dataNormalization the normalizer ot save (may be null)
- throws IOException
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file) throws IOException
Load a multi layer network from a file
- param file the file to load from
- return the loaded multi layer network
- throws IOException
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
throws IOException
Load a multi layer network from a file
- param file the file to load from
- return the loaded multi layer network
- throws IOException
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
throws IOException
Load a MultiLayerNetwork from InputStream from an input stream
Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- param is the inputstream to load from
- return the loaded multi layer network
- throws IOException
- see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is) throws IOException
Restore a multi layer network from an input stream
Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- param is the input stream to restore from
- return the loaded multi layer network
- throws IOException
- see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path) throws IOException
Load a MultilayerNetwork model from a file
- param path path to the model file, to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreMultiLayerNetwork
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path, boolean loadUpdater)
throws IOException
Load a MultilayerNetwork model from a file
- param path path to the model file, to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull String path) throws IOException
Restore a MultiLayerNetwork and Normalizer (if present - null if not) from the InputStream. Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- param is Input stream to read from
- param loadUpdater Whether to load the updater from the model or not
- return Model and normalizer, if present
- throws IOException If an error occurs when reading from the stream
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull String path, boolean loadUpdater)
throws IOException
Load a computation graph from a file
- param path path to the model file, to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater)
throws IOException
Load a computation graph from a InputStream
- param is the inputstream to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException
Load a computation graph from a InputStream
- param is the inputstream to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException
Load a computation graph from a file
- param file the file to get the computation graph from
- return the loaded computation graph
- throws IOException
restoreComputationGraph
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException
Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream. Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- param is Input stream to read from
- param loadUpdater Whether to load the updater from the model or not
- return Model and normalizer, if present
- throws IOException If an error occurs when reading from the stream
taskByModel
public static Task taskByModel(Model model)
- param model
- return
addNormalizerToModel
public static void addNormalizerToModel(File f, Normalizer<?> normalizer)
This method appends normalizer to a given persisted model.
PLEASE NOTE: File should be model file saved earlier with ModelSerializer
- param f
- param normalizer
addObjectToFile
public static void addObjectToFile(@NonNull File f, @NonNull String key, @NonNull Object o)
Add an object to the (already existing) model file using Java Object Serialization. Objects can be restored using {- link #getObjectFromFile(File, String)}
- param f File to add the object to
- param key Key to store the object under
- param o Object to store using Java object serialization
Last modified 9mo ago