diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-09 15:18:36 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-09 15:18:36 -0800 |
| commit | a611d4b20cdfab59efff1b7b7a980c6a8be40a30 (patch) | |
| tree | 8abded25b1a9d36eb4fae89fbf30895da86a12c3 /docs | |
| parent | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (diff) | |
Update user guide on `[PrimalSubstitute]`
Diffstat (limited to 'docs')
| -rw-r--r-- | docs/user-guide/07-autodiff.md | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index 6b21ee28c..4a9bb92ab 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -487,6 +487,73 @@ The following builtin functions are backward differentiable and both their forwa - Matrix operations: `transpose`, `determinant` - Legacy blending and lighting intrinsics: `dst`, `lit` +## Primal Substitute Functions + +Sometimes it is desirable to replace a function with another when generating forward or backward derivative propagation code. For example, the following code shows a function that computes the integral of some term by sampling and we want to use a different sampling stragegy when computing the derivatives. +```csharp +float myTerm(float x) +{ + return someComplexComputation(x); +} + +float getSample(float a, float b) { ... } + +[BackwardDifferentiable] +float computeIntegralOverMyTerm(float x, float a, float b) +{ + float sum = 0.0; + for (int i = 0; i < SAMPLE_COUNT; i++) + { + let s = no_diff getSample(a, b); + let y = myTerm(s); + sum += y * ((b-a)/SAMPLE_COUNT); + } + return sum; +} +``` + +In this code, the `getSample` function returns a random sample in the range of `[a,b]`. Assume we have another sampling function `getSampleForDerivativeComputation(a,b)` that we wish to use instead in derivative computation, we can do so by marking it as a primal-substitute of `getSample`, as in the following code: +```csharp +[PrimalSubstituteOf(getSample)] +float getSampleForDerivativeComputation(float a, float b) +{ + ... +} +``` + +Here, the `[PrimalSubstituteOf(getSample)]` attributes marks the `getSampleForDerivativeComputation` function as the substitute for `getSample` in derivative propagation functions. When a function has a primal subsittute, the compiler will treat all calls to that function as if it is a call to the substiute function when generating derivative code. Note that this only applies to compiler generated derivative function and does not affect user provided derivative functions. If a user provided derivative function calls `getSample`, it will not be replaced by `getSampleForDerivativeComputation` by the compiler. + +Similar to `[ForwardDerivative]` and `[ForwardDerivativeOf]` attributes, The `[PrimalSubsitute(substFunc)]` attribute works the other way around: it specifies the primal substitute function of the function being marked. + +Primal subsitute can be used as another way to make a function differentiable. A function is considered differentiable if it has a primal subsitute that is differentiable. The following code illustrates this mechanism. +```csharp +float myFunc(float x) {...} + +[PrimalSubstituteOf(myFunc)] +[BackwardDifferentiable] +float myFuncSubst(float x) {...} + +// myFunc is now considered backward differentiable. +``` + +The following example shows in more detail on how primal subsitute affects derivative computation. +```csharp +float myFunc(float x) { return x*x; } + +[PrimalSubstituteOf(myFunc)] +[ForwardDifferentiable] +float myFuncSubst(float x) { return x*x*x; } + +[ForwardDifferentiable] +float caller(float x) { return myFunc(x); } + +let a = caller(4.0); // a == 16.0 (calling myFunc) +let b = __fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst) +let c = __fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst) +``` + +In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal subsitute function will be used instead of the derivatives defined on the original function. + ## Excluding Parameters From Differentiation Sometimes we do not wish a parameter to be considered differentiable despite it has a differentiable type. We can use the `no_diff` modifier on the parameter to inform the compiler to treat the parameter as non-differentiable and skip generating differentiation code for the parameter. The syntax is: |
