Importing TensorFlow models
Currently, SameDiff supports importing TensorFlow frozen graphs through the various SameDiff.importFrozenTF methods. TensorFlow documentation on frozen models can be found here.
import org.nd4j.autodiff.samediff.SameDiff;
// modelFile is a java.io.File pointing to the frozen graph (.pb) file
SameDiff sd = SameDiff.importFrozenTF(modelFile);
After importing the TensorFlow model, there are two ways to find its inputs and outputs. The first is to print the graph summary:
System.out.println(sd.summary());
In the summary, the input variables are those that are not the output of any op, and the output variables are those that are not the input to any op. The second way, which lists the inputs directly, is
List<String> inputs = sd.inputs();
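The rule used to read the summary can be sketched on a toy graph. The ToyOp class and GraphEndpoints helper below are hypothetical and are not part of the SameDiff API. Constants are modelled as ops with no inputs (frozen TensorFlow graphs store weights as Const nodes), so only true placeholders are reported as inputs.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical, minimal graph model (NOT the SameDiff API): each op lists
// the variables it reads and the variables it produces.
final class ToyOp {
    final String name;
    final List<String> inputs;   // variables this op reads
    final List<String> outputs;  // variables this op produces

    ToyOp(String name, List<String> inputs, List<String> outputs) {
        this.name = name;
        this.inputs = new ArrayList<>(inputs);
        this.outputs = new ArrayList<>(outputs);
    }
}

final class GraphEndpoints {
    private GraphEndpoints() {}

    // Graph inputs: variables that some op reads but no op produces.
    static List<String> inputs(List<ToyOp> ops) {
        Set<String> produced = new LinkedHashSet<>();
        for (ToyOp op : ops) {
            produced.addAll(op.outputs);
        }
        Set<String> result = new LinkedHashSet<>(); // keeps first-seen order
        for (ToyOp op : ops) {
            for (String v : op.inputs) {
                if (!produced.contains(v)) {
                    result.add(v);
                }
            }
        }
        return new ArrayList<>(result);
    }

    // Graph outputs: variables that some op produces but no op reads.
    static List<String> outputs(List<ToyOp> ops) {
        Set<String> consumed = new LinkedHashSet<>();
        for (ToyOp op : ops) {
            consumed.addAll(op.inputs);
        }
        Set<String> result = new LinkedHashSet<>();
        for (ToyOp op : ops) {
            for (String v : op.outputs) {
                if (!consumed.contains(v)) {
                    result.add(v);
                }
            }
        }
        return new ArrayList<>(result);
    }
}
```

For a graph x -> MatMul(x, w) -> h -> Softmax(h) -> probs, where w comes from a constant op, inputs returns [x] and outputs returns [probs].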
To run inference, use:
// inputName and outputName are variable names found as described above
INDArray out = sd.batchOutput()
        .input(inputName, inputArray)
        .output(outputName)
        .execSingle();
For multiple outputs, use exec() instead of execSingle() to return a Map<String, INDArray> of outputs. Alternatively, you can use methods such as SameDiff.output(Map<String, INDArray> placeholders, String... outputs) to get the same result.
We have a TensorFlow graph analysis utility, which reports any missing operations (operations that still need to be implemented), here.
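The core idea behind such a report is a set difference between the op types a graph uses and the op types the importer supports. The sketch below assumes a hypothetical map from node name to op type and a hypothetical set of supported op types. It does not reproduce the actual utility, only the counting logic.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Hypothetical helper (not the actual DL4J utility): counts how often each
// unsupported op type appears in a graph.
final class MissingOpReport {
    private MissingOpReport() {}

    // nodeToOpType: node name -> TensorFlow op type (e.g. "Conv2D")
    // supportedOps: op types the importer can map
    // Returns op type -> number of nodes using it, sorted by op type name.
    static Map<String, Integer> missingOps(Map<String, String> nodeToOpType,
                                           Set<String> supportedOps) {
        Map<String, Integer> counts = new TreeMap<>();
        for (String opType : nodeToOpType.values()) {
            if (!supportedOps.contains(opType)) {
                counts.merge(opType, 1, Integer::sum);
            }
        }
        return counts;
    }
}
```

Counting occurrences, rather than returning a plain set, shows how many nodes each missing op would affect.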
It is also possible to remove nodes from the imported network. This is useful because, for example, TensorFlow 1.x models can contain hard-coded dropout layers. See the BERT Graph test for an example.
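Removing a node that acts as an identity at inference time amounts to rewiring its consumers to read the node's input directly. The sketch below works on a hypothetical ToyNode graph, not the SameDiff API. It assumes dropout appears as a single node with exactly one input. In real TensorFlow 1.x graphs dropout usually expands into several ops, so a real removal pass would have to bypass that whole subgraph.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy graph node (NOT the SameDiff API).
final class ToyNode {
    final String name;
    final String opType;
    final List<String> inputs; // names of the nodes feeding this one

    ToyNode(String name, String opType, List<String> inputs) {
        this.name = name;
        this.opType = opType;
        this.inputs = new ArrayList<>(inputs);
    }
}

final class NodeRemoval {
    private NodeRemoval() {}

    // Removes every node of the given op type that has exactly one input and
    // rewires its consumers to that input. Returns a new list; the original
    // graph is left unchanged.
    static List<ToyNode> bypass(List<ToyNode> graph, String opType) {
        // removed node name -> name of the node it forwarded
        Map<String, String> replacement = new HashMap<>();
        for (ToyNode n : graph) {
            if (n.opType.equals(opType) && n.inputs.size() == 1) {
                replacement.put(n.name, n.inputs.get(0));
            }
        }
        List<ToyNode> result = new ArrayList<>();
        for (ToyNode n : graph) {
            if (replacement.containsKey(n.name)) {
                continue; // drop the bypassed node itself
            }
            List<String> rewired = new ArrayList<>();
            for (String in : n.inputs) {
                rewired.add(resolve(in, replacement));
            }
            result.add(new ToyNode(n.name, n.opType, rewired));
        }
        return result;
    }

    // Follows chains of removed nodes (e.g. two stacked dropouts).
    // Assumes the graph is acyclic, so the loop terminates.
    private static String resolve(String name, Map<String, String> replacement) {
        String current = name;
        while (replacement.containsKey(current)) {
            current = replacement.get(current);
        }
        return current;
    }
}
```

With x -> MatMul -> Dropout -> Dropout -> Softmax, bypassing "Dropout" leaves three nodes and the Softmax node reads directly from the MatMul node.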