| author | Yong He <yonghe@outlook.com> | 2023-07-12 13:02:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-12 13:02:57 -0700 |
| commit | 39b7df94b287b2115f41ca038d560102246d0696 (patch) | |
| tree | d0066f0a9bc88ecdcc04f39a167b6c04c4b1d33a /docs/user-guide/07-autodiff.md | |
| parent | d0901aa7933ac31b0bf7648a31ec5c13de864457 (diff) | |
Update autodiff documentation. (#2979)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'docs/user-guide/07-autodiff.md')
| -rw-r--r-- | docs/user-guide/07-autodiff.md | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md
index dc19ba0da..b3be3fbca 100644
--- a/docs/user-guide/07-autodiff.md
+++ b/docs/user-guide/07-autodiff.md
@@ -30,9 +30,9 @@ float myFunc(float a, float x)
 }
 ```
 
-This allows the function to be used in the `__fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function.
+This allows the function to be used in the `fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function.
 
-The expression `__fwd_diff(myFunc)` will have the following signature:
+The expression `fwd_diff(myFunc)` will have the following signature:
 ```csharp
 DifferentialPair<float> myFunc_fwd_derivative(DifferentialPair<float> a, DifferentialPair<float> x);
 ```
@@ -44,7 +44,7 @@ To use this function to compute the derivative of `myFunc` with regard to `x`, t
 float a = 2.0;
 float x = 3.0;
 // Compute derivative with regard to `x`:
-let result = __fwd_diff(myFunc)(diffPair(a, 0.0), diffPair(x, 1.0));
+let result = fwd_diff(myFunc)(diffPair(a, 0.0), diffPair(x, 1.0));
 
 // Print the derivative.
 printf("%f", result.d);
@@ -57,9 +57,9 @@ In the example code above, `diffPair()` is a builtin function to construct a val
 
 The forward derivative function allows the user to compute the derivative of a function with regard to a specific combination of input parameters at a time. In many cases, we need to know how each parameter affects the output. Instead of calling the forward derivative function once for each parameter, it is more efficient to call the *backward propagation* function that propagate the derivative of outputs to each input parameter.
 
-To allow the compiler to generate the backward propagation function, we simply mark our function with the `[BackwardDifferentiable]` attribute:
+To allow the compiler to generate the backward propagation function, we simply mark our function with the `[Differentiable]` or `[BackwardDifferentiable]` attribute:
 ```csharp
-[BackwardDifferentiable]
+[Differentiable]
 float myFunc(float a, float x)
 {
     return a * x * x;
@@ -67,10 +67,10 @@ float myFunc(float a, float x)
 ```
 
 > #### Note:
-> When a function is marked as `[BackwardDifferentiable]`, it is implied that the function is also `[ForwardDifferentiable]` and can be used in the `__fwd_diff` operator.
+> When a function is marked as `[Differentiable]`, it is implied that the function is both `[ForwardDifferentiable]` and `[BackwardDifferentiable]` and can be used in the `fwd_diff` operator.
 
 
-The `__bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `__bwd_diff(myFunc)` will have the following signature:
+The `bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `bwd_diff(myFunc)` will have the following signature:
 
 ```csharp
 void myFunc_backProp(inout DifferentiablePair<float> a, inout DifferentiablePair<float> x, float dResult);
@@ -85,7 +85,7 @@ The backward propagation function can be called as in the following code:
 var a = diffPair(2.0); // constructs DifferentialPair{2.0, 0.0}
 var x = diffPair(3.0); // constructs DifferentialPair{3.0, 0.0}
 
-__bwd_diff(myFunc)(a, x, 1.0);
+bwd_diff(myFunc)(a, x, 1.0);
 
 // a.d is now 9.0
 // x.d is now 12.0
@@ -298,7 +298,7 @@ In this specific case, the automatically generated `IDifferentiable` implementat
 
 ## Forward Derivative Propagation Function
 
-Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propgation function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable.
+Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propgation function. Likewise, the `bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable.
 
 A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters.
 Given an original function, the signature of its forward propagation function is determined using the following rules:
@@ -339,9 +339,9 @@ A function can be made forward-differentiable with a `[ForwardDifferentiable]` a
 R original(T0 p0, inout T1, p1, T2 p2);
 ```
 
-Once the function is made forward-differentiable, the forward propagation function can then be called with the `__fwd_diff` operator:
+Once the function is made forward-differentiable, the forward propagation function can then be called with the `fwd_diff` operator:
 ```csharp
-DifferentialPair<R> result = __fwd_diff(original)(...);
+DifferentialPair<R> result = fwd_diff(original)(...);
 ```
 
 ### User Defined Forward Derivative Functions
@@ -406,7 +406,7 @@ struct T : IDifferentiable {...}
 struct R : IDifferentiable {...}
 struct ND {} // Non differentiable
 
-[BackwardDifferentiable]
+[Differentiable]
 R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5);
 ```
 The signature of its backward propagation function is:
@@ -423,16 +423,16 @@ Note that although `p2` is still `inout` in the backward propagation function, t
 
 ### Automatically Implemented Backward Propagation Functions
 
-A function can be made backward-differentiable with a `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[BackwardDifferentiable]` is:
+A function can be made backward-differentiable with a `[Differentiable]` or `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[Differentiable]` is:
 
 ```csharp
-[BackwardDifferentiable]
+[Differentiable]
 R original(T0 p0, inout T1, p1, T2 p2);
 ```
 
-Once the function is made backward-differentiable, the backward propagation function can then be called with the `__bwd_diff` operator:
+Once the function is made backward-differentiable, the backward propagation function can then be called with the `bwd_diff` operator:
 ```csharp
-__bwd_diff(original)(...);
+bwd_diff(original)(...);
 ```
 
 ### User Defined Backward Propagation Functions
@@ -498,7 +498,7 @@ float myTerm(float x)
 
 float getSample(float a, float b) { ... }
 
-[BackwardDifferentiable]
+[Differentiable]
 float computeIntegralOverMyTerm(float x, float a, float b)
 {
     float sum = 0.0;
@@ -530,7 +530,7 @@ Primal subsitute can be used as another way to make a function differentiable. A
 float myFunc(float x) {...}
 
 [PrimalSubstituteOf(myFunc)]
-[BackwardDifferentiable]
+[Differentiable]
 float myFuncSubst(float x) {...}
 
 // myFunc is now considered backward differentiable.
@@ -548,8 +548,8 @@ float myFuncSubst(float x) { return x*x*x; }
 float caller(float x) { return myFunc(x); }
 
 let a = caller(4.0); // a == 16.0 (calling myFunc)
-let b = __fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst)
-let c = __fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst)
+let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst)
+let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst)
 ```
 
 In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal subsitute function will be used instead of the derivatives defined on the original function.
@@ -639,7 +639,7 @@ float f(float x)
     return t.member;
 }
 ...
-let result = __fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0
+let result = fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0
 ```
 
 In this case, we are assigning the value `x*x`, which carries a derivative, into a non-differentiable location `MyType.member`, thus throwing away any derivative info. When `f` returns `t.member`, there will be no derivative associated with it so the function will not propagate the derivative through. This code is most likely not intending to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang will treat any assignments of a value with
@@ -674,7 +674,7 @@ float f(float x)
     return g(x) + x * x;
 }
 ```
-The derivative will not propagate through the call to `g` in `f`. As a result, `__fwd_diff(f)(diffPair(1.0, 1.0))` will return
+The derivative will not propagate through the call to `g` in `f`. As a result, `fwd_diff(f)(diffPair(1.0, 1.0))` will return
 `{3.0, 2.0}` instead of `{3.0, 4.0}` as the derivative from `2*x` is lost through the non-differentiable call. To prevent unintended error, it is treated as a compile-time error to call `g` from `f`. If such a non-differentiable call is intended, a `no_diff` prefix is required in the call:
 ```csharp
 [ForwardDifferentiable]
@@ -696,7 +696,7 @@ See the following code for an example of `[TreatAsDifferentiable]`:
 ```csharp
 interface IFoo
 {
-    [BackwardDifferentiable]
+    [Differentiable]
     float f(float v);
 }
 
@@ -709,7 +709,7 @@ struct B : IFoo
     }
 }
 
-[BackwardDifferentiable]
+[Differentiable]
 float use(IFoo o, float x)
 {
     return o.f(x);
@@ -717,20 +717,20 @@ float use(IFoo o, float x)
 
 // Test:
 B obj;
-float result = __fwd_diff(use)(obj, diffPair(2.0, 1.0)).d;
+float result = fwd_diff(use)(obj, diffPair(2.0, 1.0)).d;
 // result == 0.0, since `[TreatAsDifferentiable]` causes a trivial derivative implementation
 // being generated regardless of the original code.
 ```
 
 ## Higher Order Differentiation
 
-Slang supports generating higher order forward and backward derivative propagation functions. It is allowed to use `__fwd_diff` and `__bwd_diff` operators inside a forward or backward differentiable function, or to nest `__fwd_diff` and `__bwd_diff` operators. For example, `__fwd_diff(__fwd_diff(sin))` will have the following signature:
+Slang supports generating higher order forward and backward derivative propagation functions. It is allowed to use `fwd_diff` and `bwd_diff` operators inside a forward or backward differentiable function, or to nest `fwd_diff` and `bwd_diff` operators. For example, `fwd_diff(fwd_diff(sin))` will have the following signature:
 
 ```csharp
 DifferentialPair<DifferentialPair<float>> sin_diff2(DifferentialPair<DifferentialPair<float>> x);
 ```
 
-The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `__fwd_diff(__fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`.
+The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `fwd_diff(fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`.
 
 User defined higher-order derivative functions can be specified by using `[ForwardDerivative]` or `[BackwardDerivative]` attribute on the derivative function, or by using `[ForwardDerivativeOf]` or `[BackwardDerivativeOf]` attribute on the higher-order derivative function.
 
@@ -738,7 +738,7 @@ User defined higher-order derivative functions can be specified by using `[Forwa
 
 Automatic differentiation for generic functions is supported. The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a derivative function with different set of generic parameters or constraints is a compile-time error.
 
-An interface method requirement can be marked as `[ForwardDifferentiable]` or `[BackwardDifferentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported.
+An interface method requirement can be marked as `[ForwardDifferentiable]` or `[Differentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported.
 
 ## Restrictions of Automatic Differentiation
 
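The numeric values quoted in the example comments of this patch can be checked by hand. The derivation below is only a sketch. It assumes `myFunc(a, x) = a * x * x` and `myFuncSubst(x) = x*x*x` as they appear in the patched documentation, and it assumes the nested-pair fields map to dual-number coefficients as `x.p.p` = $x_0$, `x.p.d` = $x_1$, `x.d.p` = $x_2$, `x.d.d` = $x_{12}$, with $\varepsilon_1^2 = \varepsilon_2^2 = 0$. The symbols $f$, $s$, $g$, $x_i$ and $\varepsilon_i$ are illustrative names, not taken from the documentation.

```latex
\begin{aligned}
f(a, x) = a x^2 &\;\Rightarrow\; \frac{\partial f}{\partial a} = x^2,\quad \frac{\partial f}{\partial x} = 2 a x \\
(a, x) = (2, 3) &\;\Rightarrow\; f = 18,\quad \frac{\partial f}{\partial a} = 9,\quad \frac{\partial f}{\partial x} = 12 \\[4pt]
s(x) = x^3 &\;\Rightarrow\; s(4) = 64,\quad s'(4) = 3 \cdot 4^2 = 48 \\[4pt]
g(x_0 + x_1\varepsilon_1 + x_2\varepsilon_2 + x_{12}\varepsilon_1\varepsilon_2)
  &= g(x_0) + g'(x_0)\,x_1\,\varepsilon_1 + g'(x_0)\,x_2\,\varepsilon_2
   + \bigl(g''(x_0)\,x_1 x_2 + g'(x_0)\,x_{12}\bigr)\,\varepsilon_1\varepsilon_2 \\
g = \sin,\ (x_0, x_1, x_2, x_{12}) = \bigl(\tfrac{\pi}{2}, 1, 1, 0\bigr)
  &\;\Rightarrow\; \{\{1, 0\}, \{0, -1\}\}
\end{aligned}
```

These results agree with the comments `a.d is now 9.0`, `x.d is now 12.0`, `b == 64.0` and `c == 48.0`, and with the `{ { 1.0, 0.0 }, { 0.0, -1.0 } }` result stated for `fwd_diff(fwd_diff(sin))`.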
