How to add differential functions and other ops to a SameDiff graph
Notes to write on: Rewrite for new op descriptors
To get started with SameDiff, familiarize yourself with the autodiff module of the ND4J API, located on GitHub.
For better or worse, SameDiff code is organized in just a few key places. For basic usage and testing of SameDiff, the following modules are key. We'll discuss some of them in more detail in just a bit.
functions: This module has the basic building blocks to build SameDiff variables and graphs.
execution: This module has everything related to SameDiff graph execution.
gradcheck: Utility functionality for checking SameDiff gradients, similar in structure to the respective tool in DL4J.
loss: Loss functions for SameDiff.
samediff: The main SameDiff module to define, set up and run SameDiff operations and graphs.
functions module
See the functions module on GitHub.
The central abstraction of the functions module is DifferentialFunction, which underlies pretty much everything in SameDiff. Mathematically, what we're doing in SameDiff is building a directed acyclic graph whose nodes are differential functions, for which we can compute gradients. In that regard, DifferentialFunction makes up a SameDiff graph on a fundamental level.
Note that each DifferentialFunction comes with a SameDiff instance. We'll discuss SameDiff and this relationship later on. Also, while there are only a few key abstractions, they're used essentially everywhere, so it's almost impossible to discuss SameDiff concepts separately. Eventually we'll get around to each part.
Each differential function comes with properties. In the simplest case, a differential function just has a name. Depending on the operation in question, you'll usually have many more properties (think strides or kernel sizes in convolutions). When we import computation graphs from other projects (TensorFlow, ONNX, etc.), these properties need to be mapped to the conventions we're using internally. The methods attributeAdaptersForFunction, mappingsForFunction, propertiesForFunction and resolvePropertiesFromSameDiffBeforeExecution are what you want to look at to get started.
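To make the idea of property mapping more concrete, here is a minimal, self-contained Java sketch that does not use ND4J. Its two maps only play the roles described above: one maps an imported attribute name to an internal property name (the job of mappingsForFunction), the other converts the attribute value (the job of attributeAdaptersForFunction). The attribute names, the NHWC strides layout and the padding-to-boolean adapter are illustrative assumptions, not the actual mappings of any ND4J op.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

public class PropertyMappingSketch {
    /** Hypothetical mapping: imported attribute name -> internal property name. */
    private static final Map<String, String> MAPPINGS = new LinkedHashMap<>();
    /** Hypothetical adapters: internal property name -> value conversion. */
    private static final Map<String, Function<Object, Object>> ADAPTERS = new HashMap<>();

    static {
        MAPPINGS.put("strides", "strides");
        MAPPINGS.put("padding", "isSameMode");
        // NHWC-style strides [1, sH, sW, 1] -> internal [sH, sW]
        ADAPTERS.put("strides", v -> {
            long[] s = (long[]) v;
            return new long[]{s[1], s[2]};
        });
        // Padding string "SAME"/"VALID" -> boolean flag
        ADAPTERS.put("isSameMode", v -> "SAME".equals(v));
    }

    /** Resolves imported attributes into internal properties; unmapped attributes are ignored. */
    public static Map<String, Object> resolve(Map<String, Object> attrs) {
        Map<String, Object> props = new LinkedHashMap<>();
        for (Map.Entry<String, String> m : MAPPINGS.entrySet()) {
            if (!attrs.containsKey(m.getKey())) {
                continue;
            }
            // Properties without an adapter are copied unchanged.
            Function<Object, Object> adapter = ADAPTERS.getOrDefault(m.getValue(), Function.identity());
            props.put(m.getValue(), adapter.apply(attrs.get(m.getKey())));
        }
        return props;
    }
}
```

For example, resolving {strides=[1, 2, 3, 1], padding="SAME"} yields strides [2, 3] and isSameMode true, while an unmapped attribute such as a data format string is simply dropped.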
A differential function is executed on a list of inputs, using function properties, and produces one or more output variables. You have access to many helper functions to set or access these variables:
args(): returns all input variables.
arg(): returns the first input variable (the only one for unary operations).
larg() and rarg(): return the first and second (read "left" and "right") arguments for binary operations.
outputVariables(): returns a list of all output variables. Depending on the operation, this may be computed dynamically. As we'll see later on, to get the result for ops with a single output, we'll call .outputVariables()[0].
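The following minimal Java sketch illustrates how these helpers relate to a function's inputs and outputs. MiniFunction is a hypothetical stand-in that represents variables by name only; it is not DifferentialFunction, and the guard in larg()/rarg() is an illustrative assumption rather than documented ND4J behavior.

```java
public class ArgHelpersSketch {
    /** Hypothetical stand-in for a differential function; variables are represented by name only. */
    public static final class MiniFunction {
        private final String[] inputs;
        private final String[] outputs;

        public MiniFunction(String[] inputs, String[] outputs) {
            this.inputs = inputs.clone();
            this.outputs = outputs.clone();
        }

        public String[] args() { return inputs.clone(); }           // all inputs
        public String arg() { return inputs[0]; }                   // first input (the only one for unary ops)
        public String larg() { requireBinary(); return inputs[0]; } // "left" argument of a binary op
        public String rarg() { requireBinary(); return inputs[1]; } // "right" argument of a binary op
        public String[] outputVariables() { return outputs.clone(); }

        // Illustrative guard: larg()/rarg() only make sense for binary ops.
        private void requireBinary() {
            if (inputs.length != 2) {
                throw new IllegalStateException("larg()/rarg() need exactly 2 inputs, got " + inputs.length);
            }
        }
    }
}
```

For a binary op built as new MiniFunction(new String[]{"a", "b"}, new String[]{"c"}), larg() is "a", rarg() is "b", and the single result is outputVariables()[0], i.e. "c".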
Handling output variables is tricky and one of the pitfalls in using and extending SameDiff. For instance, implementing calculateOutputShape for a differential function might be necessary, but if implemented incorrectly it can lead to hard-to-debug failures. (Note that SameDiff will eventually call op execution in libnd4j, and dynamic custom ops either infer output shapes or need to be provided with the correct output shape.)
Automatic differentiation for a differential function is implemented in a single method: doDiff. Each operation has to provide an implementation of doDiff. If you're implementing a SameDiff operation for a libnd4j op x and you're lucky enough to find x_bp (as in "back-propagation"), you can use that, and your doDiff implementation comes essentially for free.
You'll also see a diff implementation that's used internally and calls doDiff.
Importantly, each differential function has access to a factory, an instance of DifferentialFunctionFactory, by calling f(). More precisely, this will return the factory of the SameDiff instance the differential function has.
This is used in many places and gives you access to all differential functions currently registered in SameDiff. Think of this factory as a provider of operations. For example, sum is exposed in DifferentialFunctionFactory by a method that simply redirects to the Sum operation defined elsewhere in ND4J and then returns the first output variable (of type SDVariable, discussed in a second). Disregarding the implementation details for now, what this allows you to do is call f().sum(...) from anywhere you have access to a differential function factory. For instance, when implementing a SameDiff op x and you already have x_bp in your function factory, you can override doDiff for x so that it calls x_bp through the factory.
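The factory pattern and the "doDiff for free" idea can be sketched in plain Java as follows. This is a minimal model under stated assumptions: Var, Op, Factory, Square and SquareBp are hypothetical stand-ins (not SDVariable, DifferentialFunctionFactory or real ND4J ops), and values are computed eagerly to keep the example small. Each factory method wraps one op and returns outputVariables()[0], and Square.doDiff only delegates to the backprop op obtained through f().

```java
import java.util.Collections;
import java.util.List;

public class FactorySketch {
    /** Hypothetical stand-in for SDVariable: holds an already computed value. */
    public static final class Var {
        public final double[] value;
        public Var(double... value) { this.value = value; }
    }

    /** Hypothetical stand-in for a differential function with outputs. */
    public static abstract class Op {
        protected final Var[] inputs;
        protected Op(Var... inputs) { this.inputs = inputs; }
        public abstract Var[] outputVariables();
    }

    /** Sums all elements of its single input. */
    public static final class Sum extends Op {
        public Sum(Var in) { super(in); }
        @Override public Var[] outputVariables() {
            double s = 0;
            for (double v : inputs[0].value) s += v;
            return new Var[]{new Var(s)};
        }
    }

    /** Hypothetical "square_bp": dL/dx = 2 * x * dL/dy, elementwise. */
    public static final class SquareBp extends Op {
        public SquareBp(Var x, Var gradOut) { super(x, gradOut); }
        @Override public Var[] outputVariables() {
            double[] x = inputs[0].value, g = inputs[1].value;
            double[] out = new double[x.length];
            for (int i = 0; i < x.length; i++) out[i] = 2 * x[i] * g[i];
            return new Var[]{new Var(out)};
        }
    }

    /** Hypothetical factory: each method wraps one op and returns its first output. */
    public static final class Factory {
        public Var sum(Var in) { return new Sum(in).outputVariables()[0]; }
        public Var squareBp(Var x, Var gradOut) { return new SquareBp(x, gradOut).outputVariables()[0]; }
    }

    /** Elementwise x^2 whose doDiff just delegates to squareBp through f(). */
    public static final class Square extends Op {
        private final Factory factory;
        public Square(Factory factory, Var x) { super(x); this.factory = factory; }
        public Factory f() { return factory; }
        @Override public Var[] outputVariables() {
            double[] x = inputs[0].value;
            double[] out = new double[x.length];
            for (int i = 0; i < x.length; i++) out[i] = x[i] * x[i];
            return new Var[]{new Var(out)};
        }
        public List<Var> doDiff(List<Var> gradOut) {
            return Collections.singletonList(f().squareBp(inputs[0], gradOut.get(0)));
        }
    }
}
```

With x = [1, 2, 3] and an upstream gradient of ones, Square.doDiff returns [2, 4, 6], and f().sum(new Var(1, 2, 3)) returns a variable holding 6.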
samediff module
See the samediff module on GitHub.
Not surprisingly, this is where the magic happens. This module has the core structures that SameDiff operates with. First, let's have a look at the variables that make up SameDiff operations.
SDVariable (read SameDiff variable) extends DifferentialFunction and is to SameDiff what INDArray is to good old ND4J. In particular, SameDiff graphs operate on these variables, and each individual operation takes in and spits out a list of SDVariables. An SDVariable comes with a name, is equipped with a SameDiff instance, has shape information and knows how to initialize itself with an ND4J WeightInitScheme. You'll also find a few helpers to set and get these properties.
One of the few things an SDVariable can do that a DifferentialFunction can't is evaluate its result and return an underlying INDArray by calling eval(). This will run SameDiff internally and retrieve the result. A similar getter is getArr(), which you can call at any point to get the current value of this variable. This functionality is used extensively in testing, to assert proper results. An SDVariable also has access to its current gradient through gradient(). Upon initialization there won't be any gradient; it will usually be computed at a later point.
Apart from these methods, SDVariable also carries methods for concrete ops (and is in that regard a little similar to DifferentialFunctionFactory). For instance, defining an add method on SDVariable allows you to call c = a.add(b) on two SameDiff variables, the result of which can be accessed by c.eval().
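The split between building an expression (c = a.add(b)) and evaluating it (c.eval()) can be illustrated with a minimal Java sketch. Node is a hypothetical, scalar-only stand-in for SDVariable: calling add or mul only records the operation, and eval() walks the recorded graph. Real SameDiff execution works on INDArrays and a full graph executor, which this sketch does not model.

```java
import java.util.function.DoubleBinaryOperator;

public class LazyVariableSketch {
    /** Hypothetical stand-in for SDVariable: building an expression records it, eval() computes it. */
    public static final class Node {
        private final Double value;              // set for leaf variables only
        private final Node left, right;          // set for op outputs only
        private final DoubleBinaryOperator op;

        private Node(Double value, Node left, Node right, DoubleBinaryOperator op) {
            this.value = value;
            this.left = left;
            this.right = right;
            this.op = op;
        }

        public static Node var(double value) { return new Node(value, null, null, null); }

        // Op methods on the variable itself, in the spirit of a.add(b); nothing is computed here.
        public Node add(Node other) { return new Node(null, this, other, (a, b) -> a + b); }
        public Node mul(Node other) { return new Node(null, this, other, (a, b) -> a * b); }

        /** Recursively evaluates the recorded graph. */
        public double eval() {
            return value != null ? value : op.applyAsDouble(left.eval(), right.eval());
        }
    }
}
```

With a = var(2) and b = var(3), c = a.add(b) does no arithmetic; c.eval() returns 5.0, and c.mul(b).eval() returns 15.0.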
The SameDiff class is the main workhorse of the module and brings together most of the concepts discussed so far. Somewhat unfortunately, the inverse is also true: SameDiff instances are part of all other SameDiff module abstractions in some way or another (which is why you've seen it many times already). Generally speaking, SameDiff is the main entry point for automatic differentiation, and you use it to define a symbolic graph that carries operations on SDVariables. Once built, a SameDiff graph can be run in a few ways, for instance with exec() and execAndEndResult().
Convince yourself that invoking SameDiff() sets up a million things! Essentially, SameDiff will collect and give you access (in terms of both getters and setters) to:
All differential functions for the graph, with all their properties, which can be accessed in various ways (e.g. name or id).
All input and output information for said functions.
All function properties and how to map them; propertiesToResolve and propertiesForFunction are of particular note.
SameDiff is also the place where you expose new operations to the SameDiff module. Essentially, you write a little wrapper for the respective operation in the DifferentialFunctionFactory instance f(). Cross products, for example, are exposed this way.
At this point it might be a good idea to check out and run a few examples. SameDiff tests are a good source for that. SameDiffTests, one of the main test sources, shows for example how to multiply two SameDiff variables, and it also contains a few complete end-to-end examples.
The second place you find tests is the samediff directory of the repository. Whenever you add a new operation to SameDiff, add tests for the forward pass as well as gradient checks.
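The idea behind a gradient check is to compare the analytic gradient produced by doDiff with a numerical estimate. The following self-contained Java sketch shows this for a scalar function using central finite differences; it is not the gradcheck module's API, and the step size and tolerance are illustrative choices.

```java
import java.util.function.DoubleUnaryOperator;

public class GradCheckSketch {
    /** Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps). */
    public static double numericGrad(DoubleUnaryOperator f, double x, double eps) {
        return (f.applyAsDouble(x + eps) - f.applyAsDouble(x - eps)) / (2 * eps);
    }

    /** Relative error, guarded against division by zero when both values are ~0. */
    public static double relError(double a, double b) {
        return Math.abs(a - b) / Math.max(Math.abs(a) + Math.abs(b), 1e-12);
    }

    /** True if the analytic gradient matches the numeric one at every point within maxRelError. */
    public static boolean check(DoubleUnaryOperator f, DoubleUnaryOperator analyticGrad,
                                double[] points, double eps, double maxRelError) {
        for (double x : points) {
            double numeric = numericGrad(f, x, eps);
            double analytic = analyticGrad.applyAsDouble(x);
            if (relError(numeric, analytic) > maxRelError) {
                return false;
            }
        }
        return true;
    }
}
```

For cos, the correct gradient x -> -Math.sin(x) passes check(Math::cos, ..., new double[]{0.1, 0.5, 1.0, 2.0}, 1e-6, 1e-5), while a wrong sign (x -> Math.sin(x)) gives a relative error of 1 and fails.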
The third set of relevant tests is stored in imports and contains tests for importing TensorFlow and ONNX graphs. On a side note, the resources for these import tests are generated in our TFOpsTests project.
We've seen how ND4J operations get picked up by DifferentialFunctionFactory and SameDiff to expose them to SameDiff at various levels. As for actually implementing these ops, you need to know a few things. In libnd4j you find two classes of operations, legacy operations and dynamic custom operations, which are documented in detail separately. We'll show how to implement both op types.
All operations go here, and most of the time it's obvious where exactly to put the ops. Special attention goes to layers, which is reserved for deep learning layer implementations (like Conv2D). These higher-level ops are based on the concept of Modules, similar to modules in PyTorch or layers in TensorFlow. These layer op implementations are also a good source of more involved op implementations.
Legacy (or XYZ) operations are the old breed of ND4J operations with a characteristic "xyz" signature. Here's how to implement cosine in ND4J by wrapping the cos legacy op from libnd4j: Cosine implementation. When it comes to SameDiff, the good thing about legacy ops is that they're already available in ND4J, but they need to be augmented by SameDiff-specific functionality to pass muster. Since the cosine function does not have any properties, this implementation is straightforward. The parts that make this op SameDiff compliant are:
You specify SameDiff constructors here.
You implement doDiff here.
You specify a SameDiff opName, a TensorFlow tensorflowName and an ONNX onnxName here.
If you look closely, this is only part of the truth, since Cos extends BaseTransformOp, which implements other SameDiff functionality. (Note that BaseTransformOp is a BaseOp, which extends DifferentialFunction from earlier.) For instance, calculateOutputShape is implemented there. If you want to implement a new transform, you can simply inherit from BaseTransformOp, too. For other op types, like reductions, there are op base classes available as well, meaning you only need to address the three points above.
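The division of labor between a transform base class and a concrete op like Cos can be sketched in plain Java. TransformOp below is a hypothetical stand-in for BaseTransformOp that supplies the shared parts (output shape, elementwise execution), so the subclass only provides its names, its elementwise function and doDiff. The name strings are what we assume for cos and are shown for illustration only; the classes are not the ND4J implementation.

```java
public class LegacyTransformSketch {
    /** Hypothetical stand-in for BaseTransformOp: shared machinery for elementwise ops. */
    public static abstract class TransformOp {
        protected final double[] x;
        protected final long[] shape;

        protected TransformOp(double[] x, long... shape) {
            this.x = x;
            this.shape = shape;
        }

        /** Elementwise op: the output shape equals the input shape. */
        public long[] calculateOutputShape() { return shape.clone(); }

        public double[] exec() {
            double[] out = new double[x.length];
            for (int i = 0; i < x.length; i++) out[i] = apply(x[i]);
            return out;
        }

        public abstract String opName();
        public abstract String tensorflowName();
        public abstract String onnxName();
        protected abstract double apply(double v);
        /** Gradient w.r.t. the input, given the upstream gradient of the output. */
        public abstract double[] doDiff(double[] gradOut);
    }

    /** Only the op-specific parts: names, the elementwise function and doDiff. */
    public static final class Cos extends TransformOp {
        public Cos(double[] x, long... shape) { super(x, shape); }
        @Override public String opName() { return "cos"; }
        @Override public String tensorflowName() { return "Cos"; }
        @Override public String onnxName() { return "Cos"; }
        @Override protected double apply(double v) { return Math.cos(v); }
        @Override public double[] doDiff(double[] gradOut) {
            double[] g = new double[x.length];
            for (int i = 0; i < x.length; i++) g[i] = -Math.sin(x[i]) * gradOut[i]; // d/dx cos(x) = -sin(x)
            return g;
        }
    }
}
```

Adding another elementwise op in this model would mean writing one more small subclass; nothing in TransformOp has to change.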
In the rare case you need to write a legacy op from scratch, you'll have to find the respective op number from libnd4j, which can be found in legacy_ops.h.
DynamicCustomOp is the new kind of operation from libnd4j, and all recent additions are implemented as such. This operation type in ND4J directly extends DifferentialFunction.
As an example, consider the BatchToSpace operation, which inherits from DynamicCustomOp. BatchToSpace is initialized with two properties, blocks and crops. Note how blocks and crops, which are both of integer type, get added to the integer arguments of the operation by calling addIArgument. For float arguments and other types, use addTArgument instead.
The operation gets its own name and names for import, and doDiff is implemented.
The BatchToSpace operation is then integrated into DifferentialFunctionFactory here, exposed to SameDiff here and tested here.
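How integer properties end up in a custom op's integer argument list can be sketched in Java. CustomOp and BatchToSpace below are hypothetical stand-ins, not ND4J classes; the only assumption carried over from the text is that integer properties go through addIArgument and floating point ones through addTArgument. The order used for flattening crops is an illustrative choice.

```java
import java.util.ArrayList;
import java.util.List;

public class CustomOpArgsSketch {
    /** Hypothetical stand-in for the argument lists of a DynamicCustomOp. */
    public static class CustomOp {
        private final List<Long> iArguments = new ArrayList<>();   // integer args
        private final List<Double> tArguments = new ArrayList<>(); // floating point args

        public void addIArgument(long... args) { for (long a : args) iArguments.add(a); }
        public void addTArgument(double... args) { for (double a : args) tArguments.add(a); }

        public List<Long> iArgs() { return new ArrayList<>(iArguments); }
        public List<Double> tArgs() { return new ArrayList<>(tArguments); }
    }

    /** Hypothetical BatchToSpace: both integer properties end up in the integer argument list. */
    public static final class BatchToSpace extends CustomOp {
        private final int[] blocks;
        private final int[][] crops;

        public BatchToSpace(int[] blocks, int[][] crops) {
            this.blocks = blocks.clone();
            this.crops = crops;
            for (int b : blocks) addIArgument(b);
            // crops is flattened row by row into the same integer argument list
            for (int[] row : crops) {
                for (int c : row) addIArgument(c);
            }
        }
    }
}
```

With blocks [2, 2] and crops [[0, 0], [1, 1]], the integer arguments are [2, 2, 0, 0, 1, 1] and no float arguments are added.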
Let's look at another operation that does property mapping right, namely DynamicPartition. This op has precisely one property, called numPartitions in SameDiff. To map and use this property, you do the following:
Implement a little helper method called addArgs that is used in the constructor of the op and in an import helper one-liner that we're discussing next. It's not strictly necessary, but it is encouraged to do this and to call it addArgs consistently, for clarity.
Note that while DynamicPartition has proper property mapping, it currently does not have a working doDiff implementation.
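The point of a single addArgs helper is that the constructor and the import path fill the op's arguments the same way. The following Java sketch shows that pattern with a hypothetical DynamicPartition stand-in; the attribute name num_partitions and the initFromAttributes method are assumptions for illustration and do not reproduce ND4J's import hooks.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class AddArgsSketch {
    /** Hypothetical op with one property, numPartitions, that must end up in its integer args. */
    public static final class DynamicPartition {
        private int numPartitions;
        private final List<Long> iArguments = new ArrayList<>();

        /** No-arg constructor, e.g. for import; properties are filled in afterwards. */
        public DynamicPartition() { }

        /** Constructor used when building a graph directly. */
        public DynamicPartition(int numPartitions) {
            this.numPartitions = numPartitions;
            addArgs();
        }

        /** The one place that turns properties into op arguments. */
        private void addArgs() {
            iArguments.clear();
            iArguments.add((long) numPartitions);
        }

        /** Hypothetical import path: read the attribute, then reuse addArgs(). */
        public void initFromAttributes(Map<String, Object> attrs) {
            this.numPartitions = ((Number) attrs.get("num_partitions")).intValue();
            addArgs();
        }

        public List<Long> iArgs() { return new ArrayList<>(iArguments); }
    }

    /** Builds the op through the (hypothetical) import path. */
    public static List<Long> viaImport(int n) {
        Map<String, Object> attrs = new HashMap<>();
        attrs.put("num_partitions", n);
        DynamicPartition op = new DynamicPartition();
        op.initFromAttributes(attrs);
        return op.iArgs();
    }
}
```

Both new DynamicPartition(3) and viaImport(3) produce the integer arguments [3], because both go through addArgs.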
As a last example, we show one that has a more interesting property mapping setup, namely Dilation2D. Not only does this op have far more properties to map, as you can see in mappingsForFunction, but the properties also come with property values, as defined in attributeAdaptersForFunction. We've chosen to show this op because it is one that has property mapping, but is exposed neither to DifferentialFunctionFactory nor to SameDiff.
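Hence, the three DynamicCustomOp examples shown each come with their own defects and represent examples of the work that has to be done for SameDiff. To summarize, to add a new SameDiff op you need to implement the op class (constructors, names for import, property mapping and doDiff), expose it in DifferentialFunctionFactory and SameDiff, and add tests for it.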
In order to properly add ops, we have a set of code generation modules under contrib. Read the basics on the module here. It is recommended to use this module when adding new ops to SameDiff: after the op is added in libnd4j, use the code generation module to update the SameDiff API.
After updates, a few constructors for your op class defined above may need to be added. Leverage your IDE to add the necessary constructors. The constructors usually added fall into two camps:
SameDiff SDVariable-based constructors. All you generally need to do for these constructors is add something like super(samediff,new SDVariable[]{..},null). This will invoke a super constructor with samediff and the necessary inputs.
INDArray-based constructors with super(new INDArray[]{..},null), which pass just the inputs and leave the outputs unspecified.
Similar to the code generation module above, we also generate the specs/op definitions found here. The op spec definition generator can be found here. As with the module above, you will want to run the op parser generator and update the src/main/resources directory of the nd4j-api module, in the same location as linked previously.
Upon running the main class, an op-ir.proto file will appear in the root directory. Copy the contents of that file into the nd4j-api module. For now, it's recommended to just import the class into your IDE and run the main class from there.
This project contains the ND4J Op definitions, the DSL (Domain Specific Language) that is used for those definitions, and code generators that use those definitions to create the actual Java code used to call the defined operations.
As we started to support SameDiff, we also started to introduce inconsistencies between SameDiff and ND4J. Even though both of those libraries use the same underlying implementations for operations, there are both small and large differences in the API that we provide for them. Sometimes, we have provided an official API only for one usage, and not the other. And very often the documentation for a single op is in many different places.
In the future we want to support other programming languages with libnd4j, and provide more ways to use our C++ backend. This would only increase the aforementioned problems.
The root of all of those problems is that ops are used across different environments, and there is no single way of defining them with an enforced interface.
The solution we propose is to define the operations separately, and then generate the necessary API code for them. All of the generated code is to be considered untouchable; editing it will result in the changes being overwritten sooner rather than later.
The combination of external op definition and code generation opens up many opportunities for us. The first one being that we can easily create consistent APIs for both ND4J and SameDiff in Java. But, looking into the future, we can also create those APIs for other programming languages like Python, Swift, or even C#. We can even go beyond programming languages, and use the op definitions to create better documentation than what JavaDoc or similar might support out of the box.
This project is currently maintained by Paul Dubs, with feedback often collected from raver119 and Alex Black.
At the moment we still focus on nailing down an easily readable and contribution friendly DSL for op definition and code generation that can replace namespace definitions. This means that at the moment we still rely on the pre-existing Op definition classes that already exist in ND4J.
Replace Bitwise and Random namespaces with autogenerated code – In progress.
Implement a convenient CLI tool.
Define all Ops using the DSL.
Automatically generate derivative op declarations from existing ops
Replace all namespace definitions in ND4J / SameDiff with automatically generated ones
Replace all Op classes with automatically generated ones.
Pre-requisites:
JDK 8 or higher
Maven 3.3 or higher
TODO: Show usage output of the project itself
TODO: Show how to use from mvn
A script, generate.sh, is provided in the project root. At present, it can be used to generate ND4J namespace classes. It is assumed that the deeplearning4j mono repo and the dl4j-dev-tools repo both exist and have a common parent directory, i.e., somedir/deeplearning4j and somedir/dl4j-dev-tools both exist.
The script takes as arguments the name (or names) of the ND4J namespaces to generate (not case sensitive) and the projects to generate them for (supported projects are nd4j, sd and both, with both as the default).
As of 26/11, namespace names (and hence valid args) include bitwise, neuralnetwork, random and math. Note also that all may be passed to the script to generate all namespaces.
For example, to generate both bitwise and random namespaces for both nd4j and SameDiff:
Or to generate all namespaces for both nd4j and SameDiff, use:
To generate namespaces for one project only, use:
or:
The script will first compile the project before running. Internally, the org.nd4j.codegen.cli.CLI class is used. Classes are written to deeplearning4j/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/
It is possible to use generate.sh to generate code only, docs in markdown format only, or both docs and code. To generate docs only and store them in a new folder "docs" for all namespaces:
Generation for selected namespaces works in the same way as for code:
The project is implemented using a mix of Java and Kotlin. The DSL definition and the accompanying data structures are implemented in Kotlin. At the moment the code generators are implemented in Java, in order to allow people who know Java but are not fluent in Kotlin to contribute to the code generators.
The source code for this project is structured a bit differently than what you would typically see in a Java or Kotlin project. When you take a look inside the src/main directory, you will find 4 sub-directories.
The java and kotlin directories contain Java and Kotlin code respectively.
In order to not confuse op definitions with the machinery that allows them to be defined in that way, ops are kept in a separate folder called ops.
Because we use JavaPoet for the Java code generator implementation, we also have yet another folder called stubs. That folder contains stub classes that are used to reference other classes available in ND4J. These stub classes are intentionally left empty, as JavaPoet only requires them for naming and automatically creating proper imports. We use stub classes instead of depending on the actual ND4J API in order to break a cyclic dependency that would otherwise be created (i.e. in order to be able to generate code for ND4J, we would need an already compiled ND4J to be available). Note: If something is stubbed here and is moved in ND4J, then it also has to be moved to the appropriate place here; otherwise the generated code will be wrong.
The adr folder contains "Architecture Decision Records". These files give you more insight into the "why" of some of the bigger decisions within this project.
Ops are defined using a DSL that is implemented in Kotlin. This means that other than the DSL, as defined in the following, you can also use all of Kotlin when defining Ops. However, doing things the obvious and clearly understandable way is better than coming up with a clever way, so prefer to use the DSL as described if unsure.
This example shows how a namespace is defined. Namespaces are at the top layer, and ops can only be defined within the context of a namespace. This example namespace contains only a single op, called "add". If we wanted to add another op, we would simply add it below the first.
As you can see, every op has to have a name; if you try to create one without a name, you will get a compile error. Within the context of the op, we first set the Java package in which the op class can be found, then define its inputs, arguments and outputs, and finally add some free form documentation about what that op is doing.
Like with the op itself, the inputs, arguments and outputs all have to have a name, but unlike the op, they also require a type. Within their context, you can set a description and a count of how many parameters they can take respectively.
If an input, argument or output takes anything other than exactly 1, it will be treated as an array. Typically you would use this to define ops like concat, which can take multiple input tensors, or ops that might take shape arguments.
The following shows what a typical op definition looks like and how the generated Java code may look.
An op might be defined like this:
The Java code generator will create a method like the following for it:
Or an op with some more constraints:
will be converted to Java like this:
Defines a namespace.
Only available within a namespace context
Every op requires a namespace unique op name.
When defining an op, you can also pass a mixin that it should inherit initial properties from. This has the same effect as using useMixin(mixin) as the very first thing in the op definition. If you don't want to inherit all of the parameters of the mixin, you can pass the same additional configuration as you would pass to useMixin(mixin, ...options...). See Mixin for more information.
javaPackage (String): Package where the op is to be found in the Java implementation.
javaOpClass (String): Name of the Java op class if inconsistent with opName. Default: same as opName
libnd4jName (String): The name the op has in libnd4j. Default: same as opName
Available in global context.
Mixins provide the facility to share commonalities between Ops. You can think of it like inheritance, especially when you declare the use of a mixin on Op definition. In contrast to normal (single) inheritance, where only a single super class is possible, the mixin mechanism allows you to "inherit" from multiple sources.
You can define almost all the same things within a mixin that you can within an Op. The only things that cannot be configured within a mixin are Op name, libnd4jName and javaOpClass.
As mixins can be configured within the global context, you can share them across namespaces by defining them in their own file. If a mixin is namespace specific, you can also define it within the namespace context.
Mixins are used either on definition as a parameter, Op("opname", mixin){...}, or with useMixin(mixin) within the op definition. While the former version only supports a single mixin, the latter version allows you to use as many mixins as are required.
You can also build up mixins by using useMixin(mixin) inside a Mixin itself.
useMixin(mixin, ...options...) supports a few additional options: keepInputs, keepArgs, keepConfigs, keepOutputs, keepSignatures, keepDoc and keepConstraints. They default to true. If you want to skip including some of them, you simply set the parameter for it to false, e.g. useMixin(mixin, keepDoc=false).
When using useMixin(mixin), all definitions within the mixin are applied as if this invocation was replaced with the content of the mixin itself. This means that if you have already defined anything prior to using a mixin, the mixin's definitions will come after the previously defined things. This can be very useful if the commonality between ops is that they have a few trailing options.
If a named property or section is defined in both a mixin (or multiple mixins) and the op, then the last to define it will win. Named properties are legacy and javaPackage; named sections are Input, Arg, Output and Config.
For example, assume you have javaPackage defined in both an op and a mixin. Then you can have the following two cases:
First case:
Second case:
In the first case, the op will have the javaPackage value that is defined within the op. In the second case it will have the javaPackage value defined in the mixin.
For inputs, args and outputs, it works similarly. Assume you have Input(dataType, "a") defined in both the mixin and the op. Again you can have two cases:
First case:
Second case:
In the first case, it will overwrite the input from the mixin. In the second case, the mixin will overwrite the input from the op.
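The "last definition wins" rule can be modeled in a few lines of Java. Def is a hypothetical model of an op or mixin body (named properties and named inputs kept in definition order), and useMixin simply replays the mixin's entries at the point where it is called. This is not the Kotlin DSL itself, and the package names are placeholders.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MixinOrderSketch {
    /** Hypothetical model of an op or mixin body: named properties and named inputs, in definition order. */
    public static final class Def {
        public final Map<String, String> properties = new LinkedHashMap<>();
        public final Map<String, String> inputs = new LinkedHashMap<>(); // input name -> data type

        public Def property(String key, String value) { properties.put(key, value); return this; }
        public Def input(String name, String dataType) { inputs.put(name, dataType); return this; }

        /** Like useMixin(mixin): replay the mixin's definitions at this point, so later definitions win. */
        public Def useMixin(Def mixin) {
            properties.putAll(mixin.properties);
            inputs.putAll(mixin.inputs);
            return this;
        }
    }

    /** The op sets javaPackage after using the mixin, so the op's value wins. */
    public static String opDefinesLast() {
        Def mixin = new Def().property("javaPackage", "org.example.frommixin");
        return new Def().useMixin(mixin).property("javaPackage", "org.example.fromop").properties.get("javaPackage");
    }

    /** The op sets javaPackage before using the mixin, so the mixin's value wins. */
    public static String mixinAppliedLast() {
        Def mixin = new Def().property("javaPackage", "org.example.frommixin");
        return new Def().property("javaPackage", "org.example.fromop").useMixin(mixin).properties.get("javaPackage");
    }
}
```

In this model, an input "x" defined before useMixin(mixin) stays first, and a new input "y" coming from the mixin is appended after it, matching the trailing-options use case described above.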
Only available within a namespace context
Every config requires a namespace unique name.
A config allows you to define a configuration class that can be used as a holder for complex properties of specific ops, which will be passed to an op as a parameter.
Similar to an op itself, it supports Input, Arg, Constraint and Doc definitions.
In order to use the config within an op, you either use useConfig(cfg) or val configRef = useConfig(cfg). The second form allows you to reference the config.
Referencing the config allows you to reference its inputs and args by name: configRef.input("name") and configRef.arg("name"). It also allows you to use a config in a signature: Signature(a, b, c, configRef).
When default and shorthand signatures are used, configs will be always placed at the end.
If a config is defined but not used, an IllegalStateException will be thrown.
See also ADR 0007 "Configuration Objects".
Available within an op, mixin and a config context
Inputs represent tensors. They are what the op will work on.
Every input requires a data type (either INT, FLOATING_POINT, NUMERIC or BOOLEAN) and an op unique name.
When defining an input, you can assign it to a variable in order to be able to reference it later on. You might want to do this when defining constraints.
If you want an input to represent an array, you will have to set a count accordingly. If no count is set, it is assumed that the count is meant to be Exactly(1).
description (String): A short description of what this input represents. Setting this is recommended.
count (Count): Can take one of Exactly(n); AtLeast(n); AtMost(n); Range(from, to)
defaultValue (Input): Use another input as the default if this isn't set explicitly. The data type of the other input has to match the data type of this input. The other input may also have a default value.
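The semantics of the four Count variants, and the rule that anything other than Exactly(1) is treated as an array, can be summarized in a small Java model. This is a hypothetical sketch, not the Kotlin Count classes of the codegen project; in particular, treating Range(from, to) as inclusive on both ends is an assumption.

```java
public class CountSketch {
    /**
     * Hypothetical Java model of the DSL's Count variants.
     * Assumption: Range(from, to) is inclusive on both ends.
     */
    public static abstract class Count {
        public abstract boolean accepts(int n);

        /** Anything other than Exactly(1) is treated as an array. */
        public boolean isArray() { return true; }

        public static Count exactly(final int n) {
            return new Count() {
                @Override public boolean accepts(int k) { return k == n; }
                @Override public boolean isArray() { return n != 1; }
            };
        }

        public static Count atLeast(final int n) {
            return new Count() {
                @Override public boolean accepts(int k) { return k >= n; }
            };
        }

        public static Count atMost(final int n) {
            return new Count() {
                @Override public boolean accepts(int k) { return k <= n; }
            };
        }

        public static Count range(final int from, final int to) {
            return new Count() {
                @Override public boolean accepts(int k) { return k >= from && k <= to; }
            };
        }
    }
}
```

Note that AtMost(1) still counts as an array in this model, because only Exactly(1) is the single-value case.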
Available within an op, mixin and config context
Args represent arguments. They modify how the op works on its inputs.
Every arg requires a data type (either INT, FLOATING_POINT, NUMERIC or BOOLEAN) and an op unique name.
When defining an arg, you can assign it to a variable in order to be able to reference it later on. You might want to do this when defining constraints.
If you want an arg to represent an array, you will have to set a count accordingly. If no count is set, it is assumed that the count is meant to be Exactly(1).
Note (Java specific): If the last arg is defined to represent an array, it will be translated to a vararg parameter, e.g. Arg(INT, "a"){ count = AtLeast(1); description = "..." } will be turned into long... a.
description (String): A short description of what this argument represents. Setting this is recommended.
count (Count): Can take one of Exactly(n); AtLeast(n); AtMost(n); Range(from, to)
defaultValue (null|Number|Boolean|int[]|double[]|boolean[]|Arg|TensorShapeValue|TensorDataTypeValue|String): Use the given value as the default value if this isn't explicitly set. Can refer to inputs and outputs using x.shape() and x.dataType(). The given default value has to match the data type of this argument. May also refer to another Arg, and that Arg may also have a default value. Default values based on outputs are treated as if there were no default in SameDiff mode.
possibleValues (String[]): Only available when the ENUM data type is used for the argument. Takes a list of possible values for the enum. If used in an abstract base op, the enum will only be created once. See also ADR 0006 "Op specific enums".
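To illustrate the Java-specific vararg translation, here is a hand-written sketch of the shape such a generated method could take for an op whose last arg is an INT array with count = AtLeast(1). The method name reduceSum, its String stand-in for the input, the string it returns and the explicit length check are all hypothetical; real generated code operates on SDVariable or INDArray and may validate counts differently.

```java
import java.util.Arrays;

public class VarargSketch {
    /**
     * Hypothetical shape of a generated method for a definition along the lines of
     *   Input(NUMERIC, "x"); Arg(INT, "dimensions"){ count = AtLeast(1) }
     * Because the array arg comes last, it is emitted as a long... vararg.
     */
    public static String reduceSum(String x, long... dimensions) {
        // Stand-in for the AtLeast(1) count check; generated code may validate differently.
        if (dimensions.length < 1) {
            throw new IllegalArgumentException("dimensions must contain at least 1 element");
        }
        return "reduce_sum(" + x + ", dimensions=" + Arrays.toString(dimensions) + ")";
    }
}
```

Callers can then write reduceSum("x", 0, 1) or pass an existing array, reduceSum("x", new long[]{2}); only the last parameter can be a vararg in Java, which is why the rule applies to the last arg only.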
Only available within an op and mixin context
Every output requires a data type (either INT, FLOATING_POINT, NUMERIC or BOOLEAN) and an op unique name.
While outputs can be assigned to a variable, there is no intended use-case for it. In contrast to inputs and args, outputs can not be used in constraints.
description (String): A short description of what this output represents. Setting this is recommended.
Only available within an op and mixin context
For some ops only specific signatures make sense, as for example some optional parameters may become required in the presence of other optional parameters. This feature is mainly meant to help with the fact that not all programming languages (e.g. Java) support default parameters. Each signature is meant to describe one overload in those languages.
See also ADR 0005 "Optional parameters and signatures".
Signatures can also reference the output(s) of an op. Those signatures are only relevant in NDArray programming mode. They are not to be generated in SameDiff mode.
AllParamSignature() and AllDefaultParamSignature() are shorthands for Signature(...all parameters...) and Signature(...only parameters with no default values...). Their parameters include references to outputs unless disabled using withOutput=false (e.g. AllParamSignature(withOutput=false)).
If no signature is specified for an op, it is treated as if AllParamSignature() and AllDefaultParamSignature() are both specified.
Each signature must satisfy the condition that all required parameters are listed in it. If this condition is not satisfied, an IllegalStateException will be thrown on construction.
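The relationship between the two shorthand signatures and the validation rule can be modeled in Java. Params is a hypothetical model that records, for each parameter in declaration order, whether it has a default value; references to outputs are deliberately left out. It mirrors the rules described above, not the codegen project's actual classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SignatureSketch {
    /** Hypothetical parameter list: name -> whether it has a default value, in declaration order. */
    public static final class Params {
        private final Map<String, Boolean> hasDefault = new LinkedHashMap<>();

        public Params param(String name, boolean withDefault) {
            hasDefault.put(name, withDefault);
            return this;
        }

        /** Like AllParamSignature() (outputs omitted): every parameter. */
        public List<String> allParamSignature() { return new ArrayList<>(hasDefault.keySet()); }

        /** Like AllDefaultParamSignature() (outputs omitted): only parameters without default values. */
        public List<String> allDefaultParamSignature() {
            List<String> out = new ArrayList<>();
            for (Map.Entry<String, Boolean> e : hasDefault.entrySet()) {
                if (!e.getValue()) out.add(e.getKey());
            }
            return out;
        }

        /** A custom signature must list every required (no-default) parameter. */
        public List<String> signature(String... names) {
            List<String> sig = Arrays.asList(names);
            for (String required : allDefaultParamSignature()) {
                if (!sig.contains(required)) {
                    throw new IllegalStateException("Signature " + sig + " is missing required parameter " + required);
                }
            }
            return new ArrayList<>(sig);
        }
    }
}
```

For parameters x and y without defaults and alpha with a default, the two shorthands give [x, y, alpha] and [x, y], while signature("x", "alpha") throws an IllegalStateException because y is missing.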
Only available within an op and mixin context
Documentation can be language specific, and can be set to be only given at specific places. The documentation itself is given as a string. Because Kotlin supports multiline strings along with proper indentation, we are using them directly here.
Note: At the moment we are only creating java code, so the documentation can use JavaDoc syntax.
You can have multiple Doc definitions; they are treated as additive.
Any instances of the following values will be replaced when generating code:
%OPNAME% -> operation name ("Add", "Sub", etc)
%LIBND4J_OPNAME% -> libnd4j op name ("add", "sub", etc)
%INPUT_TYPE% -> input / output type depending on the generated API, i.e. SDVariable for SameDiff and INDArray for ND4J
See the DocTokens class for more details.
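The token replacement itself amounts to plain string substitution. The following Java sketch shows the idea for the three tokens listed above; the render method is a hypothetical helper, and the real list of tokens and their handling live in the DocTokens class.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class DocTokenSketch {
    /** Hypothetical replacement of the three documented tokens; see DocTokens for the real list. */
    public static String render(String doc, String opName, String libnd4jName, boolean sameDiff) {
        Map<String, String> tokens = new LinkedHashMap<>();
        tokens.put("%OPNAME%", opName);
        tokens.put("%LIBND4J_OPNAME%", libnd4jName);
        tokens.put("%INPUT_TYPE%", sameDiff ? "SDVariable" : "INDArray");
        String out = doc;
        for (Map.Entry<String, String> t : tokens.entrySet()) {
            out = out.replace(t.getKey(), t.getValue()); // replaces every occurrence
        }
        return out;
    }
}
```

Rendering "%OPNAME% (%LIBND4J_OPNAME%) takes an %INPUT_TYPE%" for Add gives "Add (add) takes an SDVariable" in SameDiff mode and "Add (add) takes an INDArray" in ND4J mode, so one Doc string can serve both generated APIs.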
Available within an op, mixin and a config context.
Many ops expect their inputs and arguments to satisfy some specific rules. Those rules can be expressed with the constraint system.
Constraints are to be enforced within the frontend language, while BackendConstraints are currently only to be used as a part of the documentation. They will be enforced within the C++ backend, so there is no point in double checking them.
There is a system in place to define even complex constraints for inputs and arguments.
In a constraint definition, you can reference inputs and arguments directly if they were previously assigned to a variable using val name = Input(...). Inside the Constraint block, you can use the following operations:
eq: Compare equality (applicable to numbers and booleans), e.g. x eq 7, x eq true
neq: Compare inequality (applicable to numbers and booleans), e.g. x neq 3, x neq true
lt, lte: less than, less than or equal (applicable to numbers), e.g. x lt 3, x lte 4
gt, gte: greater than, greater than or equal (applicable to numbers), e.g. x gt 5, x gte 6
and: combine two comparisons where both have to be true, e.g. (x eq 8) and (y lt 3)
or: combine two comparisons where at least one has to be true, e.g. (x eq 8) or (y eq true)
all: combine N comparisons where all have to be true, e.g. all(x eq 8, y lt 3, z eq true)
some: combine N comparisons where at least one has to be true, e.g. some(x eq 8, y lt 3, z eq true)
not: negates a comparison, e.g. not(x eq 3)
In addition to those operations, you also get access to some more complex constraints:
sameType(...): true if all given inputs are the same type, e.g. sameType(x,y,z)
sameShape(...): true if all given inputs have the same shape, e.g. sameShape(x,y,z)
broadcastableShapes(...): true if all given inputs have broadcast compatible shapes, e.g. broadcastableShapes(x,y,z)
Inputs also get some additional methods on them to define useful constraints:
input.rank(): Rank of the given input
input.sizeAt(i): Size of the given input at the i-th dimension
input.isScalar(): Shorthand for x.rank() == 1
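In the DSL, constraints build an expression tree that the generators turn into checking code. The following Java sketch instead evaluates equivalent logic directly, to show what the combinators and the shape helpers mean. TensorInfo is a hypothetical description of an input (data type plus shape), all/some/not mirror the combinators of the same name (Predicate.and and Predicate.or play the roles of and/or), and broadcastableShapes assumes NumPy-style broadcasting rules.

```java
import java.util.Arrays;
import java.util.function.Predicate;

public class ConstraintSketch {
    /** Hypothetical description of an input: data type and shape. */
    public static final class TensorInfo {
        final String dataType;
        final long[] shape;

        public TensorInfo(String dataType, long... shape) {
            this.dataType = dataType;
            this.shape = shape;
        }

        public int rank() { return shape.length; }
        public long sizeAt(int i) { return shape[i]; }
    }

    @SafeVarargs
    public static <T> Predicate<T> all(Predicate<T>... ps) { return t -> Arrays.stream(ps).allMatch(p -> p.test(t)); }

    @SafeVarargs
    public static <T> Predicate<T> some(Predicate<T>... ps) { return t -> Arrays.stream(ps).anyMatch(p -> p.test(t)); }

    public static <T> Predicate<T> not(Predicate<T> p) { return p.negate(); }

    public static boolean sameType(TensorInfo... xs) {
        return Arrays.stream(xs).allMatch(x -> x.dataType.equals(xs[0].dataType));
    }

    public static boolean sameShape(TensorInfo... xs) {
        return Arrays.stream(xs).allMatch(x -> Arrays.equals(x.shape, xs[0].shape));
    }

    /** Assumed NumPy-style rule: align trailing dims; per position, sizes must match or be 1. */
    public static boolean broadcastableShapes(TensorInfo... xs) {
        int maxRank = 0;
        for (TensorInfo x : xs) maxRank = Math.max(maxRank, x.rank());
        for (int i = 1; i <= maxRank; i++) {
            long target = 1;
            for (TensorInfo x : xs) {
                if (x.rank() < i) continue;          // missing leading dims act like 1
                long d = x.shape[x.rank() - i];
                if (d == 1) continue;
                if (target == 1) target = d;
                else if (d != target) return false;
            }
        }
        return true;
    }
}
```

With a = [3, 4], b = [4] and c = [2, 4], broadcastableShapes(a, b) is true and broadcastableShapes(a, c) is false, and all(t -> t.rank() == 2, t -> t.sizeAt(1) == 4) accepts a.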
Some examples of constraints, and what they evaluate to. The example code contains a little bit of context.
will evaluate to:
More examples (only the constraint itself, without context code):
Some
turns to:
If you want to contribute to this project other than by adding or improving op definitions, the following sections might be of special interest to you.
The DSL is implemented using Kotlin's type-safe builders feature (see https://kotlinlang.org/docs/reference/type-safe-builders.html). The basic principle is that function calls can receive blocks that can be executed in a specified context. When combined with the fact that we are just looking to create an object graph that is then going to be used as input to the code generators, this allows us to create a very feature-rich DSL without actually having to write a lot of code to support it.
Most of the DSL specific code can be found in src/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt. The actual class definitions for the object graph we are building can be found in src/kotlin/org/nd4j/codegen/api.
If you want to add just a simple field to one of the objects, it is usually enough to directly add it to the particular class.
If you want to add a specific section to the op definition, i.e. a section like Input or Doc, you will have to add both the class for the object that it is going to be creating, as well as a function within OpBuilder.kt to create and register that section within the op.
Note: When you extend the DSL you will most likely also have to update all code generators to support the feature you have added.
Code generators can be written in either Java or Kotlin. Java has the advantage that more people will have experience in using it. Kotlin has the advantage of more convenient syntax, especially for plain string manipulation and when dealing with Enums and fixed sets of subclasses (called sealed classes in Kotlin).
All generators have to implement the org.nd4j.codegen.api.generator.Generator interface. For automatic detection by the CLI tool, they should also be within the org.nd4j.codegen.impl.LANGUAGE package, where LANGUAGE is the actual language that they generate.
Code generators can also use an auxiliary generator for constraint generation. Those auxiliary generators have to implement the org.nd4j.codegen.api.generator.ConstraintCodeGenerator interface.