How to add differential functions and other ops to a SameDiff graph
Note: this guide still needs to be rewritten for the new op descriptors.
To get started with SameDiff, familiarize yourself with the autodiff module of the ND4J API on GitHub.
For better or worse, SameDiff code is organized in just a few key places. For basic usage and testing of SameDiff, the following modules are key. We'll discuss some of them in more detail in just a bit.
functions: the basic building blocks for SameDiff variables and graphs.
execution: everything related to SameDiff graph execution.
gradcheck: utilities for checking SameDiff gradients, similar in structure to the respective tool in DL4J.
loss: loss functions for SameDiff.
samediff: the main SameDiff module, used to define, set up and run SameDiff operations and graphs.
functions module
See the functions module on GitHub.
The central abstraction of the functions module is DifferentialFunction, which underlies pretty much everything in SameDiff. Mathematically, SameDiff builds a directed acyclic graph whose nodes are differential functions, for which we can compute gradients. In that regard, DifferentialFunction is the fundamental building block of a SameDiff graph.
Note that each DifferentialFunction comes with a SameDiff instance. We'll discuss SameDiff and this relationship later on. Also, while there are only a few key abstractions, they're used essentially everywhere, so it's almost impossible to discuss SameDiff concepts separately. Eventually we'll get around to each part.
Each differential function comes with properties. In the simplest case, a differential function just has a name. Depending on the operation in question, you'll usually have many more properties (think strides or kernel sizes in convolutions). When we import computation graphs from other projects (TensorFlow, ONNX, etc.), these properties need to be mapped to the conventions we use internally. To get started, look at the methods attributeAdaptersForFunction, mappingsForFunction, propertiesForFunction and resolvePropertiesFromSameDiffBeforeExecution.
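The following minimal Java sketch makes the renaming step concrete. It assumes a convolution-like op whose imported attributes are called strides, dilations and padding. The class Conv2DImportMapping, its methods and the internal property names are hypothetical. They mirror the idea behind mappingsForFunction, not its actual signature.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical illustration of mapping imported attributes to internal
// property names; not the ND4J mappingsForFunction() API.
class Conv2DImportMapping {

    // External attribute name (as found in the imported graph) -> internal field name.
    static Map<String, String> mappingsForFunction() {
        Map<String, String> mapping = new HashMap<>();
        mapping.put("strides", "stride");
        mapping.put("dilations", "dilation");
        mapping.put("padding", "paddingMode");
        return mapping;
    }

    // Rename known attributes to their internal names; attributes without a
    // mapping (e.g. data type hints) are dropped.
    static Map<String, Object> resolve(Map<String, Object> imported) {
        Map<String, Object> internal = new HashMap<>();
        for (Map.Entry<String, String> e : mappingsForFunction().entrySet()) {
            if (imported.containsKey(e.getKey())) {
                internal.put(e.getValue(), imported.get(e.getKey()));
            }
        }
        return internal;
    }
}
```

Attributes without a mapping are dropped in this sketch. A real importer may also need to convert attribute values, not just rename them.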
A differential function is executed on a list of inputs, using its function properties, and produces one or more output variables. You have access to many helper functions to set or access these variables:
args(): returns all input variables.
arg(): returns the first input variable (the only one for unary operations).
larg() and rarg(): return the first and second (read "left" and "right") arguments of binary operations.
outputVariables(): returns all output variables. Depending on the operation, these may be computed dynamically. As we'll see later on, to get the result of an op with a single output, we call .outputVariables()[0].
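A toy model illustrates the accessors above. ToyVariable, ToyDifferentialFunction and ToyAddOp are hypothetical classes, not ND4J types. They only show how args(), arg(), larg(), rarg() and outputVariables() relate to an op's inputs and outputs.

```java
// Hypothetical toy classes mirroring the accessor names described above;
// they are not the ND4J DifferentialFunction or SDVariable classes.
class ToyVariable {
    final String name;

    ToyVariable(String name) {
        this.name = name;
    }
}

abstract class ToyDifferentialFunction {
    private final ToyVariable[] inputs;

    ToyDifferentialFunction(ToyVariable... inputs) {
        this.inputs = inputs.clone();
    }

    ToyVariable[] args() { return inputs.clone(); } // all inputs
    ToyVariable arg()    { return inputs[0]; }      // first input (unary ops)
    ToyVariable larg()   { return inputs[0]; }      // left operand (binary ops)
    ToyVariable rarg()   { return inputs[1]; }      // right operand (binary ops)

    // Subclasses decide how many outputs they produce.
    abstract ToyVariable[] outputVariables();
}

class ToyAddOp extends ToyDifferentialFunction {
    ToyAddOp(ToyVariable left, ToyVariable right) {
        super(left, right);
    }

    @Override
    ToyVariable[] outputVariables() {
        // A single output, so callers read outputVariables()[0].
        return new ToyVariable[] { new ToyVariable(larg().name + "_plus_" + rarg().name) };
    }
}
```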
Handling output variables is tricky and one of the pitfalls in using and extending SameDiff. For instance, implementing calculateOutputShape for a differential function might be necessary, but an incorrect implementation can lead to hard-to-debug failures. (Note that SameDiff eventually calls op execution in libnd4j, and dynamic custom ops either infer their output shapes or need to be provided with the correct output shape.)
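To illustrate what a shape function has to get right, here is a minimal sketch of the shape rule for a 2D matrix multiplication. The class MatMulShape and the long[] shape representation are assumptions made for this example. The real calculateOutputShape works with ND4J's own shape descriptors. The sketch validates its inputs and fails early with a readable message rather than handing a wrong output shape to native code.

```java
import java.util.Arrays;

// Hypothetical shape rule for a 2D matrix multiplication:
// [m, k] x [k, n] -> [m, n]. Not ND4J's calculateOutputShape signature.
class MatMulShape {

    static long[] calculateOutputShape(long[] a, long[] b) {
        if (a.length != 2 || b.length != 2) {
            throw new IllegalArgumentException("expected rank-2 inputs");
        }
        // Fail fast with a readable message instead of letting native code
        // receive a wrongly sized output buffer.
        if (a[1] != b[0]) {
            throw new IllegalArgumentException("inner dimensions differ: "
                    + Arrays.toString(a) + " x " + Arrays.toString(b));
        }
        return new long[] { a[0], b[1] };
    }
}
```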
Automatic differentiation for a differential function is implemented in a single method: doDiff. Each operation has to provide an implementation of doDiff. If you're implementing a SameDiff operation for a libnd4j op x and you're lucky enough to find x_bp (as in "back-propagation"), you can use it, and your doDiff implementation comes essentially for free.
You'll also see a diff implementation that's used internally and calls doDiff.
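Here is a minimal sketch of how doDiff and diff might divide the work, using elementwise multiplication and plain double[] arrays in place of SDVariables. ToyMul is hypothetical. Having diff validate the gradient before delegating is also an assumption made for illustration. The part that carries over is that doDiff takes the upstream gradient and returns one gradient per input, here via the product rule.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical toy op for elementwise multiplication; double[] stands in
// for SDVariable.
class ToyMul {
    private final double[] a;
    private final double[] b;

    ToyMul(double[] a, double[] b) {
        this.a = a;
        this.b = b;
    }

    // Product rule: d(a*b)/da = b and d(a*b)/db = a, each scaled by the
    // upstream gradient. One gradient per input, in input order.
    List<double[]> doDiff(double[] grad) {
        return Arrays.asList(times(grad, b), times(grad, a));
    }

    // Entry point used by callers; here it checks the gradient shape and
    // delegates to doDiff.
    List<double[]> diff(double[] grad) {
        if (grad.length != a.length) {
            throw new IllegalArgumentException("gradient shape mismatch");
        }
        return doDiff(grad);
    }

    private static double[] times(double[] x, double[] y) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] * y[i];
        }
        return out;
    }
}
```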
Importantly, each differential function has access to a factory, an instance of DifferentialFunctionFactory, by calling f(). More precisely, f() returns the factory of the SameDiff instance that the differential function belongs to.
This is used in many places and gives you access to all differential functions currently registered in SameDiff. Think of this factory as a provider of operations. Consider, for example, how sum is exposed to the DifferentialFunctionFactory.
The factory method, whose arguments we leave out on purpose here, simply redirects to the Sum operation defined elsewhere in ND4J and returns its first output variable (of type SDVariable, discussed in a second). Disregarding the implementation details for now, this allows you to call f().sum(...) from anywhere you have access to a differential function factory. For instance, when implementing a SameDiff op x whose x_bp is already available in your function factory, you can override doDiff for x and simply delegate to x_bp.
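Both patterns can be sketched with toy classes: exposing an op through the factory, and delegating doDiff to an existing backprop op. All names here are hypothetical stand-ins:
ToyFactory for DifferentialFunctionFactory.
ToySum for Sum.
ToySquare for an op x.
squareBp for its x_bp.
Values are double[] instead of SDVariables.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical stand-ins: ToyFactory ~ DifferentialFunctionFactory,
// ToySum ~ Sum, ToySquare ~ an op "x", squareBp ~ its "x_bp".
class ToyFactory {

    // Expose an op: construct it and return its first (only) output.
    double[] sum(double[] input) {
        return new ToySum(input).outputVariables()[0];
    }

    // Backprop op for x -> x^2: dL/dx = 2 * x * upstream gradient.
    double[] squareBp(double[] input, double[] grad) {
        double[] result = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            result[i] = 2 * input[i] * grad[i];
        }
        return result;
    }
}

class ToySum {
    private final double[] input;

    ToySum(double[] input) {
        this.input = input;
    }

    // One output: a length-1 array holding the total.
    double[][] outputVariables() {
        double total = 0;
        for (double v : input) {
            total += v;
        }
        return new double[][] { { total } };
    }
}

class ToySquare {
    private final ToyFactory factory;
    private final double[] input;

    ToySquare(ToyFactory factory, double[] input) {
        this.factory = factory;
        this.input = input;
    }

    ToyFactory f() {
        return factory;
    }

    // With a backprop op available, doDiff is a one-line delegation.
    List<double[]> doDiff(double[] grad) {
        return Collections.singletonList(f().squareBp(input, grad));
    }
}
```

Because the gradient logic lives in squareBp, ToySquare.doDiff stays a one-liner. That is the payoff described above when an x_bp op already exists.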
samediff module
See the samediff module on GitHub.
Not surprisingly, this is where the magic happens. This module has the core structures that SameDiff operates with. First, let's have a look at the variables that make up SameDiff operations.
SDVariable (read "SameDiff variable") extends DifferentialFunction and is to SameDiff what INDArray is to good old ND4J. In particular, SameDiff graphs operate on these variables, and each individual operation takes in and spits out a list of SDVariables. An SDVariable comes with a name, is equipped with a SameDiff instance, has shape information and knows how to initialize itself with an ND4J WeightInitScheme. You'll also find a few helpers to set and get these properties.
One of the few things an SDVariable can do that a DifferentialFunction can't is evaluate its result and return the underlying INDArray by calling eval(). This runs SameDiff internally and retrieves the result. A similar getter is getArr(), which you can call at any point to get the current value of this variable. This functionality is used extensively in testing to assert correct results. An SDVariable also has access to its current gradient through gradient(). Upon initialization there won't be any gradient; it is usually computed at a later point.
Apart from these methods, SDVariable also carries methods for concrete ops (and is in that regard a little similar to DifferentialFunctionFactory). For instance, defining an add method on SDVariable allows you to call c = a.add(b) on two SameDiff variables and access the result with c.eval().
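A lazily evaluated toy variable can mimic the call pattern c = a.add(b) followed by c.eval(). LazyVariable is hypothetical and says nothing about how SDVariable is implemented internally. It only shows that add() records an operation and eval() computes the result when asked.

```java
import java.util.function.Supplier;

// Hypothetical lazily evaluated variable: add() records the operation and
// eval() runs it. This mimics the call pattern c = a.add(b); c.eval(),
// not the internals of ND4J's SDVariable.
class LazyVariable {
    private final Supplier<double[]> computation;

    private LazyVariable(Supplier<double[]> computation) {
        this.computation = computation;
    }

    // A constant leaf holding a defensive copy of the values.
    static LazyVariable of(double... values) {
        double[] copy = values.clone();
        return new LazyVariable(() -> copy.clone());
    }

    // Nothing is computed here; the sum is produced when eval() is called.
    LazyVariable add(LazyVariable other) {
        return new LazyVariable(() -> {
            double[] x = this.eval();
            double[] y = other.eval();
            if (x.length != y.length) {
                throw new IllegalArgumentException("shape mismatch");
            }
            double[] out = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                out[i] = x[i] + y[i];
            }
            return out;
        });
    }

    double[] eval() {
        return computation.get();
    }
}
```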
The SameDiff class is the main workhorse of the module and brings together most of the concepts discussed so far. A little unfortunately, the inverse is also true: SameDiff instances are part of all other SameDiff module abstractions in some way or another (which is why you've seen it many times already). Generally speaking, SameDiff is the main entry point for automatic differentiation, and you use it to define a symbolic graph that carries operations on SDVariables. Once built, a SameDiff graph can be run in a few ways, for instance with exec() and execAndEndResult().
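A tiny scalar graph can sketch the define-then-run workflow. ToyGraph is hypothetical. Its exec() and execAndEndResult() borrow their names from the text, but their signatures and return types are assumptions made for this example and are not SameDiff's API.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

// Hypothetical define-then-run graph. Method names echo exec() and
// execAndEndResult() from the text, but this is not the SameDiff API.
class ToyGraph {
    private static final class Op {
        final String output;
        final DoubleBinaryOperator fn;
        final String left;
        final String right;

        Op(String output, DoubleBinaryOperator fn, String left, String right) {
            this.output = output;
            this.fn = fn;
            this.left = left;
            this.right = right;
        }
    }

    private final Map<String, Double> inputs = new HashMap<>();
    private final List<Op> ops = new ArrayList<>();
    private final Set<String> defined = new HashSet<>();

    // Declare an input variable with a value.
    String var(String name, double value) {
        inputs.put(name, value);
        defined.add(name);
        return name;
    }

    // Record an op symbolically; nothing is computed yet. Inputs must already
    // exist, so definition order is also a valid execution order.
    String op(String output, DoubleBinaryOperator fn, String left, String right) {
        if (!defined.contains(left) || !defined.contains(right)) {
            throw new IllegalArgumentException("undefined input for " + output);
        }
        ops.add(new Op(output, fn, left, right));
        defined.add(output);
        return output;
    }

    // Run every op once, in definition order, and return all variable values.
    Map<String, Double> exec() {
        Map<String, Double> values = new HashMap<>(inputs);
        for (Op step : ops) {
            values.put(step.output, step.fn.applyAsDouble(values.get(step.left), values.get(step.right)));
        }
        return values;
    }

    // Run the graph and return only the value of the last op defined.
    double execAndEndResult() {
        return exec().get(ops.get(ops.size() - 1).output);
    }
}
```

This sketch requires an op's inputs to exist before the op is defined. That keeps definition order a valid execution order, so no explicit topological sort is needed.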
Convince yourself that invoking SameDiff() sets up a million things! Essentially, SameDiff will collect, and give you access to (via both getters and setters):
All differential functions for the graph, with all their properties, which can be accessed in various ways (e.g. by name or id).
All input and output information for said functions.
All function properties and how to map them; propertiesToResolve and propertiesForFunction are of particular note.
SameDiff is also the place where you expose new operations to the SameDiff module. Essentially, you write a little wrapper around the respective operation in the DifferentialFunctionFactory instance f(). Cross products are one operation exposed this way.
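Here is a toy sketch of such a wrapper for cross products, assuming 3-element double[] vectors in place of SDVariables. CrossFactory and CrossGraph are hypothetical. The structure is an illustrative assumption rather than a copy of SameDiff's code: an unnamed overload delegates to a named one, which calls the factory and registers the result.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the factory: it does the actual math.
class CrossFactory {
    double[] cross(double[] a, double[] b) {
        if (a.length != 3 || b.length != 3) {
            throw new IllegalArgumentException("cross product needs 3-vectors");
        }
        return new double[] {
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]
        };
    }
}

// Hypothetical stand-in for the graph object: its wrappers only delegate to
// f() and register the result under a name.
class CrossGraph {
    private final CrossFactory factory = new CrossFactory();
    private final Map<String, double[]> variables = new HashMap<>();
    private int counter = 0;

    CrossFactory f() {
        return factory;
    }

    // Unnamed variant: let the graph pick a name.
    double[] cross(double[] a, double[] b) {
        return cross(null, a, b);
    }

    // Named variant: compute via the factory, then register the result.
    double[] cross(String name, double[] a, double[] b) {
        double[] result = f().cross(a, b);
        String key = (name != null) ? name : "cross_" + counter++;
        variables.put(key, result);
        return result;
    }

    double[] get(String name) {
        return variables.get(name);
    }
}
```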
At this point, it might be a good idea to check out and run a few examples. SameDiff tests are a good source for that, including examples of multiplying two SameDiff variables.
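In the same spirit, here is a toy example that multiplies two vectors and checks the gradient of sum(a * b) numerically, similar in idea to the gradcheck utilities listed among the key modules. MulGradCheck is hypothetical and uses central finite differences rather than SameDiff's own gradient checker.

```java
// Hypothetical toy gradient check for L = sum(a * b); it mirrors the idea
// behind SameDiff's gradcheck utilities but does not use them.
class MulGradCheck {

    // Elementwise product a * b.
    static double[] mul(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] * b[i];
        }
        return out;
    }

    // Scalar loss L = sum(a * b).
    static double loss(double[] a, double[] b) {
        double total = 0;
        for (double v : mul(a, b)) {
            total += v;
        }
        return total;
    }

    // Analytic gradient: dL/da = b.
    static double[] analyticGradA(double[] a, double[] b) {
        return b.clone();
    }

    // Central finite differences: (L(a + eps*e_i) - L(a - eps*e_i)) / (2*eps).
    static double[] numericGradA(double[] a, double[] b, double eps) {
        double[] grad = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            double[] plus = a.clone();
            double[] minus = a.clone();
            plus[i] += eps;
            minus[i] -= eps;
            grad[i] = (loss(plus, b) - loss(minus, b)) / (2 * eps);
        }
        return grad;
    }

    // True if analytic and numeric gradients agree within tol everywhere.
    static boolean check(double[] a, double[] b, double eps, double tol) {
        double[] analytic = analyticGradA(a, b);
        double[] numeric = numericGradA(a, b, eps);
        for (int i = 0; i < a.length; i++) {
            if (Math.abs(analytic[i] - numeric[i]) > tol) {
                return false;
            }
        }
        return true;
    }
}
```

Because L is linear in a, the central difference matches b up to floating-point rounding. Checks like this become more informative for nonlinear ops.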