The built-in model import framework is an extensible way to convert models from other frameworks to the nd4j format. It lets you define mappings from another framework's file format to the nd4j file format.
Conceptually, to import a model from a different framework, a user only has to include the appropriate module from the samediff-import-* Maven coordinates. A core API is provided, plus one module for each supported framework. Custom framework importers can also be implemented (to be explained on another page).
A user may also provide placeholders to be used when running the import; otherwise, just pass an empty map for the variables. When a framework importer is created, it scans the classpath for definitions of TensorFlow ops, custom rules for importing nodes for specific ops or specific node names, and nd4j op descriptors. These elements of the model import framework are all customizable, but they are included by default so that import works out of the box.
For implementing your own custom overrides please see here for an example.
A brief explanation is below:
Annotate the class with a PrehookRule as in the above example. This will enable the runtime to discover your custom import.
When scanning, the framework looks through the annotations for the calls to intercept and runs your hook for them. It intercepts nodes with certain names (nodeNames) or nodes with certain op names (make sure this is the op name used in the framework you are importing from).
When annotating, also specify the framework name (usually onnx or tensorflow, but you can also create custom frameworks).
Afterwards, write the samediff calls that are equivalent to what the op does in the source framework. Samediff will usually have the op calls needed to implement any missing op. If you need help, please ask on the forums: https://community.konduit.ai/
Lastly, when returning a hook result, as at the bottom of the sample, be deliberate about whether you return true or false for continuing the normal import process. If the hook implements a whole op, it should return false; otherwise, return true so that the hook acts as an add-on to the normal import.
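The effect of the hook's return value can be illustrated with a small toy model in plain Java. This is only a sketch of the dispatch logic described above, not the import framework's actual API: the Hook class, the importNode method, and the step labels are all hypothetical names introduced for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Toy model of pre-hook dispatch. None of these types belong to the real import framework.
class PreHookSketch {

    // Hypothetical stand-in for an annotated hook: a framework name plus the
    // node names and op names it wants to intercept.
    static final class Hook {
        final String framework;
        final Set<String> nodeNames;
        final Set<String> opNames;
        final boolean continueNormalImport; // what the hook result reports

        Hook(String framework, List<String> nodeNames, List<String> opNames,
             boolean continueNormalImport) {
            this.framework = framework;
            this.nodeNames = new HashSet<>(nodeNames);
            this.opNames = new HashSet<>(opNames);
            this.continueNormalImport = continueNormalImport;
        }

        // A hook applies when the framework matches and either the node name or the op name matches.
        boolean matches(String fw, String nodeName, String opName) {
            return framework.equals(fw)
                    && (nodeNames.contains(nodeName) || opNames.contains(opName));
        }
    }

    // Returns the steps taken for one node: matching hooks run first, and the
    // default import runs only if no matching hook asked to stop it.
    static List<String> importNode(List<Hook> hooks, String fw, String nodeName, String opName) {
        List<String> steps = new ArrayList<>();
        boolean runDefault = true;
        for (Hook hook : hooks) {
            if (hook.matches(fw, nodeName, opName)) {
                steps.add("hook");
                runDefault = runDefault && hook.continueNormalImport;
            }
        }
        if (runDefault) {
            steps.add("default-import");
        }
        return steps;
    }

    public static void main(String[] args) {
        // Replaces a whole op for one named node, so it stops the normal import.
        Hook replacement = new Hook("tensorflow", Arrays.asList("my_node"),
                Collections.<String>emptyList(), false);
        // Add-on for every Relu op, so the normal import still runs afterwards.
        Hook addon = new Hook("onnx", Collections.<String>emptyList(),
                Arrays.asList("Relu"), true);
        List<Hook> hooks = Arrays.asList(replacement, addon);

        System.out.println(importNode(hooks, "tensorflow", "my_node", "MatMul")); // [hook]
        System.out.println(importNode(hooks, "onnx", "n1", "Relu"));              // [hook, default-import]
        System.out.println(importNode(hooks, "onnx", "n2", "Add"));               // [default-import]
    }
}
```

In this model, a hook that returns false fully replaces the default handling of the node, while a hook that returns true only adds to it.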
Implementing custom samediff ops
When implementing custom import calls, there are generally a few things of note:
The samediff instance that gets passed in is the one to be used for final output. Please consider how this may affect other parts of the graph when directly manipulating the graph itself.
The op passed in will contain all information for input variables that were resolved from the node currently being imported. In order to access ndarrays specified on the op, you can use sd.getVariable(..)
You may need to remove variables and ops if the original import is going to be replaced. This is currently a manual process and will be automated at a later date where possible. If you need help on whether you should add or remove certain op calls or variables, please feel free to ping us on the forums: https://community.konduit.ai/
Controlling the import at this level lets users make the model import work for their use case rather than waiting for a software upgrade that adds missing operations. Many cutting-edge models or operators can be supported by directly composing ops within the samediff framework.
Model conversion should be done outside of production, with the result saved as a samediff flatbuffers model. This lets end users control load times (especially for larger models).
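To save a model, call save on the samediff instance, passing the target file and a boolean flag, for example sd.save(new File("model.fb"), true) (the file name here is only a placeholder).
The second, boolean parameter controls whether the training state is saved. If you plan to retrain your model, set it to true; otherwise, false is fine.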
In order to use this, you must also include an nd4j backend, because when nd4j creates an ndarray it needs to know what chip it is operating on in order to allocate memory.
Once the graph is loaded in memory, you can use it as any normal samediff graph.
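To see what is in the graph, one option is to print its summary, for example System.out.println(graph.summary()). There, the input variables are the ones that are the output of no ops, and the output variables are the ones that are the input of no ops. Another way to find the inputs is: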
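The rule for telling inputs from outputs can be made concrete with a toy graph in plain Java. This is a sketch of the definition only, not of how samediff computes it: the Op class and the variable names are hypothetical, and the sketch ignores the distinction samediff makes between placeholders, constants, and trainable variables.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Toy op graph: each op consumes and produces named variables.
class GraphEndpointsSketch {

    // Hypothetical minimal op description, for illustration only.
    static final class Op {
        final String name;
        final List<String> inputs;
        final List<String> outputs;

        Op(String name, List<String> inputs, List<String> outputs) {
            this.name = name;
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    // All variable names in first-seen order.
    private static Set<String> allVariables(List<Op> ops) {
        Set<String> all = new LinkedHashSet<>();
        for (Op op : ops) {
            all.addAll(op.inputs);
            all.addAll(op.outputs);
        }
        return all;
    }

    // Graph inputs: variables that are the output of no op.
    static List<String> graphInputs(List<Op> ops) {
        Set<String> result = allVariables(ops);
        for (Op op : ops) {
            result.removeAll(op.outputs);
        }
        return new ArrayList<>(result);
    }

    // Graph outputs: variables that are the input of no op.
    static List<String> graphOutputs(List<Op> ops) {
        Set<String> result = allVariables(ops);
        for (Op op : ops) {
            result.removeAll(op.inputs);
        }
        return new ArrayList<>(result);
    }

    public static void main(String[] args) {
        // out = softmax(x * w + b)
        List<Op> ops = Arrays.asList(
                new Op("mmul", Arrays.asList("x", "w"), Arrays.asList("z")),
                new Op("add", Arrays.asList("z", "b"), Arrays.asList("y")),
                new Op("softmax", Arrays.asList("y"), Arrays.asList("out")));

        System.out.println(graphInputs(ops));  // [x, w, b]
        System.out.println(graphOutputs(ops)); // [out]
    }
}
```

Intermediate variables such as z and y are both produced and consumed, so they appear in neither list.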
List<String> inputs = graph.inputs();
To run inference use:
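INDArray out = graph.batchOutput()
    .input(inputs.get(0), inputArray)   // inputArray: the INDArray you supply for the first input
    .output(outputName)                 // outputName: the name of the variable you want to fetch
    .execSingle();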
For multiple outputs, use exec() instead of execSingle(), to return a Map<String,INDArray> of outputs instead. Alternatively, you can use methods such as SameDiff.output(Map<String, INDArray> placeholders, String... outputs) to get the same output.