| -rw-r--r-- | docs/user-guide/07-autodiff.md | 122 |
1 files changed, 78 insertions, 44 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index f844254d9..05011a53e 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -14,7 +14,7 @@ In this section, we briefly walkthrough how to compute forward-derivative from i Suppose the user has already written a function that computes some mathematic term: -```C# +```csharp float myFunc(float a, float x) { return a * x * x; @@ -22,7 +22,7 @@ float myFunc(float a, float x) ``` The user can make this function *forward-differentiable* by adding a `[ForwardDerivative]` attribute: -```C# +```csharp [ForwardDifferentiable] float myFunc(float a, float x) { @@ -33,14 +33,14 @@ float myFunc(float a, float x) This allows the function to be used in the `__fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function. The expression `__fwd_diff(myFunc)` will have the following signature: -```C# +```csharp DifferentialPair<float> myFunc_fwd_derivative(DifferentialPair<float> a, DifferentialPair<float> x); ``` Where `DifferentialPair<T>` is a builtin type that encodes both the primal(original) value and the derivative value of a term. To use this function to compute the derivative of `myFunc` with regard to `x`, the user can call the forward-derivative function by supplying the derivative value of `x` with `1.0` and the derivative value of `a` with `0.0`, as in the following code: -```C# +```csharp float a = 2.0; float x = 3.0; // Compute derivative with regard to `x`: @@ -58,7 +58,7 @@ In the example code above, `diffPair()` is a builtin function to construct a val The forward derivative function allows the user to compute the derivative of a function with regard to a specific combination of input parameters at a time. In many cases, we need to know how each parameter affects the output. 
Instead of calling the forward derivative function once for each parameter, it is more efficient to call the *backward propagation* function that propagate the derivative of outputs to each input parameter. To allow the compiler to generate the backward propagation function, we simply mark our function with the `[BackwardDifferentiable]` attribute: -```C# +```csharp [BackwardDifferentiable] float myFunc(float a, float x) { @@ -68,11 +68,11 @@ float myFunc(float a, float x) > #### Note: > When a function is marked as `[BackwardDifferentiable]`, it is implied that the function is also `[ForwardDifferentiable]` and can be used in the `__fwd_diff` operator. -> #### + The `__bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `__bwd_diff(myFunc)` will have the following signature: -```C# +```csharp void myFunc_backProp(inout DifferentiablePair<float> a, inout DifferentiablePair<float> x, float dResult); ``` @@ -81,7 +81,7 @@ Where `a` is an `inout DifferentiablePair` where the initial value of `a` is pas The additional `dResult` parameter is the derivative of the return value to be propagated to the input parameters. Note that in a backward propagation function, an input will become a `inout DifferentiablePair` where the `.d` property of the pair is intended for receiving the propagation result, and the return value will become an input parameter that represents the source of backward propagation. The backward propagation function can be called as in the following code: -```C# +```csharp var a = diffPair(2.0); // constructs DifferentialPair{2.0, 0.0} var x = diffPair(3.0); // constructs DifferentialPair{3.0, 0.0} @@ -91,11 +91,11 @@ __bwd_diff(myFunc)(a, x, 1.0); // x.d is now 12.0 ``` -This completes the walkthrough of automatic differentiation features. The following sections will cover each perspective the auto differentiation feature in more detail. 
+This completes the walkthrough of automatic differentiation features. The following sections will cover each perspective of the auto differentiation feature in more detail. ## Differentiable Types Slang will only generate differentiation code for values that has a *differentiable* type. A type is differentiable if it conforms to the builtin `IDifferentiable` interface. The definition of the `IDifferentiable` interface is: -```Swift +```csharp interface IDifferentiable { associatedtype Differential : IDifferentiable @@ -108,7 +108,7 @@ interface IDifferentiable static Differential dmul(This, Differential); } ``` -As defined by `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of a second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. +As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of a second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. In addition, a differentiable type must define the `zero` value of its derivative, and how to add and multiply derivative values. 
@@ -125,7 +125,7 @@ The user can make any `struct` types differentiable by implementing the `IDiffer #### Automatic Fulfillment of `IDifferentiable` Requirements Assume the user has defined the following type: -```C# +```csharp struct MyRay { float3 origin; @@ -135,7 +135,7 @@ struct MyRay ``` The type can be made differentiable by adding `IDifferentiable` conformance: -```C# +```csharp struct MyRay : IDifferentiable { float3 origin; @@ -155,7 +155,7 @@ Note that this code does not provide any explicit implementation of the `IDiffer In rare cases where more control is desired, the user can manually provide the implementation. To do so, we will first define the `Differential` type for `MyRay`, and use it to fulfill the `Differential` requirement in `MyRay`: -```Swift +```csharp struct MyRayDifferential { float3 d_origin; @@ -207,7 +207,7 @@ struct MyRay : IDifferential Note that for each struct field that is differentiable, we need to use the `[DerivativeMember]` attribute to associate it with the corresponding field in the `Differential` type, so the compiler knows how to access the derivative for the field. However, there is still a missing piece in the above code: we also need to make `MyRayDifferential` conform to `IDifferentiable` because it is required that the `Differential` of a type must itself be `Differential`. 
Again we can use automatic fulfillment by simply adding `IDifferentiable` conformance to `MyRayDifferential`: -```C# +```csharp struct MyRayDifferential : IDifferentiable { float3 d_origin; @@ -218,7 +218,7 @@ In this case, since all fields in `MyRayDifferential` are differentiable, and th We can also choose to manually implement `IDifferentiable` interface for `MyRayDifferential` as in the following code: -```Swift +```csharp struct MyRayDifferential : IDifferentiable { typealias Differential = MyRayDifferential; @@ -256,16 +256,16 @@ In this specific case, the automatically generated `IDifferentiable` implementat ## Forward Derivative Function -Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative of the function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward propagation function. This section covers the semantics of forward derivative and backward propagation functions, and different ways to make a function forward and backward differentiable. +Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative of the function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward propagation function. This and the next sections cover the semantics of forward derivative and backward propagation functions, and different ways to make a function forward and backward differentiable. A forward derivative function computes the derivative of the result value with regard to a specific set of input parameters. 
-Given an original function, the signature of its forward derivative function is computed in the following rules: +Given an original function, the signature of its forward derivative function is determined using the following rules: - If the return type `R` is differentiable, the forward derivative function will return `DifferentialPair<R>` that consists of both the computed original result value as well as the derivative of the result value. Otherwise, the return type is kept unmodified as `R`. - If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair<T>` parameter in the derivative function to allow passing in the initial derivatives of each parameter in addition to the original values. - All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. For example, given original function: -```Swift +```csharp R original(T0 p0, inout T1 p1, T2 p2); ``` Where `R`, `T0`, and `T1` is differentiable and `T2` is non-differentiable, the forward derivative function will have the following signature: @@ -274,7 +274,7 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T ``` `DifferentialPair<T>` is a builtin type that carrys both the original and derivative value of a term. It is defined as follows: -```Swift +```csharp struct DifferentialPair<T : IDifferentiable> : IDifferentiable { typealias Differential = DifferentialPair<T.Differential>; @@ -290,19 +290,19 @@ struct DifferentialPair<T : IDifferentiable> : IDifferentiable A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward-derivative function. 
The syntax for using `[ForwardDifferentiable]` is: -```Swift +```csharp [ForwardDifferentiable] R original(T0 p0, inout T1, p1, T2 p2); ``` Once the function is made forward-differentiable, the forward derivative function can then be called with the `__fwd_diff` operator: -```Swift +```csharp DifferentialPair<R> result = __fwd_diff(original)(...); ``` ### User Defined Forward Derivative Functions As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide an derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward-derivative implementation. The syntax for using `[ForwardDerivative]` attribute is: -```Swift +```csharp DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2) { .... @@ -312,7 +312,7 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T R original(T0 p0, inout T1, p1, T2 p2); ``` If `derivative` is defined in a different scope from `original`, such as in a different namespace or `struct` type, a fully qualified name is required. For example: -```Swift +```csharp struct MyType { // Implementing derivative function in a different name scope. @@ -327,9 +327,9 @@ struct MyType R original(T0 p0, inout T1, p1, T2 p2); ``` -Sometimes the derivative function needs to be defined in a different module from the original function, or it is not convenient to directly modify the definition of the original function. In this case, we can use the `[ForwardDerivativeOf(originalFunnc)]` attribute to inform the compiler that `originalFunc` should be treated as a forward-differentiable function, and the current function is the derivative implementation of `originalFunc`. 
The following code will have the same effect to associate `derivative` and the forward-derivative implementation of `original`: +Sometimes the derivative function needs to be defined in a different module from the original function, or the derivative function cannot be made visible from the original function. In this case, we can use the `[ForwardDerivativeOf(originalFunnc)]` attribute to inform the compiler that `originalFunc` should be treated as a forward-differentiable function, and the current function is the derivative implementation of `originalFunc`. The following code will have the same effect to associate `derivative` and the forward-derivative implementation of `original`: -```Swift +```csharp R original(T0 p0, inout T1, p1, T2 p2); [ForwardDerivativeOf(original)] @@ -342,12 +342,12 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T ## Backward Propagation Function A backward propagation function propagates the derivative of the function output to all the input parameters simultaneously. -Given an orignal function, the signature of its backward propagation function is computed using the following rules: +Given an orignal function, the signature of its backward propagation function is determined using the following rules: - A back-prop function always returns `void`. - A differentiable `in` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The propagated derivative will be written to the derivative part of the differential pair after the back-prop function returns. The initial derivative value of the pair is ignored as input. - A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the result derivative of the return value to propagate back to other input parameters. 
-- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the resulting derivative of the argument is passed as input to the back-prop function as the original and derivative part of the pair. The propagated derivative to this input parameter will be written back and replace the derivative part of the pair. -- A differentialbe return value of type `R` will become an additional `in R.Differential` parameter at the end of the back-prop function parameter list, carrying the result derivative of the return value to propagate back to the input parameters. +- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the resulting derivative of the argument is passed as input to the back-prop function as the original and derivative part of the pair. The propagated derivative to this input parameter will be written back and replace the derivative part of the pair. The resulting primal value will *not* be written back to the parameter. +- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the back-prop function parameter list, carrying the result derivative of the return value to propagate back to the input parameters. - A non-differential return value of type `NDR` will be dropped. - A non-differential `in` parameter of type `ND` will remain unchanged in the back-prop function. - A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the back-prop function. @@ -356,7 +356,7 @@ Given an orignal function, the signature of its backward propagation function is The general rule is that any differentiable output becomes an input derivative parameter, and any non-differentiable outputs are dropped from the back-prop function. 
This means that the back-prop function never returns any values computed in the original function. For example consider the following original function: -```Swift +```csharp struct T : IDifferentiable {...} struct R : IDifferentiable {...} struct ND {} // Non differentiable @@ -365,7 +365,7 @@ struct ND {} // Non differentiable R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5); ``` The signature of its backward propagation function is: -```Swift +```csharp void back_prop( inout DifferentialPair<T> p0, T.Differential p1, @@ -374,18 +374,19 @@ void back_prop( ND p5, R.Differential dResult); ``` +Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`. ### Automatically Implemented Backward Propagation Functions A function can be made backward-differentiable with a `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[BackwardDifferentiable]` is: -```Swift +```csharp [BackwardDifferentiable] R original(T0 p0, inout T1, p1, T2 p2); ``` Once the function is made backward-differentiable, the backward propagation function can then be called with the `__bwd_diff` operator: -```Swift +```csharp __bwd_diff(original)(...); ``` @@ -393,7 +394,7 @@ __bwd_diff(original)(...); Similar to user-defined forward derivative functions, the `[BackwardDerivative]` and `[BackwardDerivativeOf]` attributes can be used to supply a function with user defined backward propagation function. 
The syntax for using `[BackwardDerivative]` attribute is: -```Swift +```csharp void back_prop( inout DifferentialPair<T> p0, T.Differential p1, @@ -409,9 +410,9 @@ void back_prop( R original(T0 p0, inout T1, p1, T2 p2); ``` -Similarly the `[BackwardDerivativeOf]` attribute can be used on the back-prop function in case it is not convenient to modify the definition of the original function: +Similarly the `[BackwardDerivativeOf]` attribute can be used on the back-prop function in case it is not convenient to modify the definition of the original function, or the back-prop function can't be made visible from the original function: -```Swift +```csharp R original(T0 p0, inout T1, p1, T2 p2); [BackwardDerivativeOf(original)] @@ -441,43 +442,76 @@ The following builtin functions are backward differentiable and both their forwa Sometimes we do not wish a parameter to be considered differentiable despite it has a differentiable type. We can use the `no_diff` modifier on the parameter to inform the compiler to treat the parameter as non-differentiable and skip generating differentiation code for the parameter. The syntax is: -```C# +```csharp // Only differentaite this function with regard to `x`. float myFunc(no_diff float a, float x); ``` The forward derivative and backward propgation functions of `myFunc` should have the following signature: -```C# +```csharp DifferentialPair<float> fwd_derivative(float a, DifferentialPair<float> x); void back_prop(float a, inout DifferentialPair<float> x, float dResult); ``` In addition, the `no_diff` modifier can also be used on the return type to indicate the return value should be considered non-differentiable. 
For example, the function -```C# +```csharp no_diff float myFunc(no_diff float a, float x, out float y); ``` Will the the following forward derivative and backward propagation function signatures: -```C# +```csharp float fwd_derivative(float a, DifferentialPair<float> x); void back_prop(float a, inout DifferentialPair<float> x, float d_y); ``` +## Calling Non-Differentiable Function from a Differentiable Function +Calling non-differentiable function from a differentiable function is allowed. However, derivatives will not be propagated through the call. The user is required to clarify the intention by prefixing the call with the `no_diff` keyword. An unclarified +call to non-differnetiable function will result in a compile-time error. + +For example, consider the following code: +```csharp +float g(float x) +{ + return 2*x; +} + +[ForwardDifferentiable] +float f(float x) +{ + // Error: implicit call to non-differentiable function g. + return g(x) + x * x; +} +``` +The derivative will not propagate through the call to `g` in `f`. As a result, `__fwd_diff(f)(diffPair(1.0, 1.0))` will return +`{3.0, 2.0}` instead of `{3.0, 4.0}` as the derivative from `2*x` is lost through the non-differentiable call. To prevent unintended error, it is treated as a compile-time error to call `g` from `f`. If such a non-differentiable call is intended, a `no_diff` prefix is required in the call: +```csharp +[ForwardDifferentiable] +float f(float x) +{ + // OK. The intention to call a non-differentiable function is clarified. + return no_diff g(x) + x * x; +} +``` + +However, the `no_diff` keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependant on the derivative being propagated through the call. + ## Higher Order Differentiation Slang supports generating higher order forward derivative functions. 
It is allowed to use `__fwd_diff` operator inside a forward differentiable function, or to nest `__fwd_diff` operators. For example, `__fwd_diff(__fwd_diff(sin))` will have the following signature: -```C# +```csharp DifferentialPair<DifferentialPair<float>> sin_diff2(DifferentialPair<DifferentialPair<float>> x); ``` The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `__fwd_diff(__fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`. -Currently, Slang only supports nesting of the `__fwd_diff` operator. The `__bwd_diff` operator cannot be nested. +Currently, Slang only supports nesting of the `__fwd_diff` operator. The `__bwd_diff` operator cannot be nested. Using `__bwd_diff` operator in a forward derivative or backward propagation function is now allowed and will result in compile-time error. + +User defined higher-order derivative functions can be specified by using `[ForwardDerivative]` attribute on the derivative function, or by using `[ForwardDerivativeOf]` attribute on the higher-order derivative function. ## Interactions with Generics and Interfaces -Automatic differentiation for generic functions is supported. The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a function with different set of generic parameters or constraints is a compile-time error. +Automatic differentiation for generic functions is supported. 
The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a derivative function with different set of generic parameters or constraints is a compile-time error. An interface method requirement can be marked as `[ForwardDifferentiable]` or `[BackwardDifferentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported. @@ -490,4 +524,4 @@ The compiler can generate forward derivative and backward propagation implementa - If a differentiable function contains calls that causes side-effects such as writes to global memory, there will not be a guarantee on how many times the side-effect will occur during the resulting derivative function or back-propagation function. - All loops in a backward differentiable function must end within a statically known number of iterations. If the maximum number of iterations is not trivially deductible by the type system as a compile-time constant, a manually attribute is needed at the loop to provide the number. If the number of actually executed iterations exceeds what is being specified, the resulting runtime behavior is undefined. -The above restrictions do no apply if a user-defined derivative or backward propagation function is provided.
\ No newline at end of file
+The above restrictions do no apply if a user-defined derivative or backward propagation function is provided.
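The numeric results quoted in the patched guide can be checked with a small dual-number sketch. This is only a Python analogy of forward-mode differentiation, not Slang and not how the Slang compiler implements `__fwd_diff`; the names `Dual`, `lift`, `my_func`, `g_no_diff`, `f_tracked`, and `f_no_diff` are hypothetical and exist only for this illustration. It reproduces the `12.0` derivative from the `myFunc` walkthrough and the `{3.0, 2.0}` versus `{3.0, 4.0}` difference described in the new `no_diff` section.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Hypothetical stand-in for a (primal, derivative) pair."""
    p: float  # primal (original) value
    d: float  # derivative value

    def __add__(self, other):
        other = lift(other)
        return Dual(self.p + other.p, self.d + other.d)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (u * v)' = u * v' + u' * v
        return Dual(self.p * other.p, self.p * other.d + self.d * other.p)

    __rmul__ = __mul__


def lift(v):
    """Treat a plain number as a constant (zero derivative)."""
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)


def my_func(a, x):
    # Same formula as the guide's myFunc: a * x * x
    return a * x * x


def g_no_diff(x):
    # Models a call that does not propagate derivatives:
    # the primal is computed, the derivative is dropped.
    return Dual(2.0 * x.p, 0.0)


def f_tracked(x):
    # g(x) + x * x with the derivative of g(x) = 2 * x tracked
    return 2.0 * x + x * x


def f_no_diff(x):
    # g(x) + x * x with the derivative of g discarded
    return g_no_diff(x) + x * x


# Derivative of my_func with regard to x at a = 2, x = 3:
# seed x.d = 1 and a.d = 0.
print(my_func(Dual(2.0, 0.0), Dual(3.0, 1.0)))  # Dual(p=18.0, d=12.0)

# Effect of dropping the derivative through g at x = 1:
print(f_tracked(Dual(1.0, 1.0)))  # Dual(p=3.0, d=4.0)
print(f_no_diff(Dual(1.0, 1.0)))  # Dual(p=3.0, d=2.0)
```

Seeding `x.d` with `1.0` and `a.d` with `0.0` selects the derivative with regard to `x`, which mirrors the calling convention the guide describes for the forward-derivative function.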
