author Yong He <yonghe@outlook.com> 2023-02-08 13:04:32 -0800
committer GitHub <noreply@github.com> 2023-02-08 13:04:32 -0800
commit 80b1b372dc131beefeda224ffa619b2b995173bd
tree a205c494969bc5819ace77fd6891437fc24d2f4b
parent b1d7dc0707406da69d94f3915fa48f39020b93ab
Update autodiff documentation with more precise math definitions. (#2636)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r-- docs/_layouts/user-guide.html 19
-rw-r--r-- docs/user-guide/07-autodiff.md 97
-rw-r--r-- docs/user-guide/toc.html 5
3 files changed, 93 insertions, 28 deletions
diff --git a/docs/_layouts/user-guide.html b/docs/_layouts/user-guide.html
index 1a4be7030..b9a387686 100644
--- a/docs/_layouts/user-guide.html
+++ b/docs/_layouts/user-guide.html
@@ -389,6 +389,25 @@
updateCurrentSubsection(findCurrentSubsection());
</script>
+ <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
+ <script type="text/x-mathjax-config">
+ MathJax.Hub.Config({
+ tex2jax: {
+ inlineMath: [ ['$$','$$'], ["\\(","\\)"] ],
+ displayMath: [ ['$$','$$'], ["\\(","\\)"] ],
+ },
+ TeX: {
+ Macros: {
+ bra: ["\\langle{#1}|", 1],
+ ket: ["|{#1}\\rangle", 1],
+ braket: ["\\langle{#1}\\rangle", 1],
+ bk: ["\\langle{#1}|{#2}|{#3}\\rangle", 3]
+ }
+ }
+ });
+ </script>
+ <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
+
{% if site.google_analytics %}
<script>
(function (i, s, o, g, r, a, m) {
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md
index 53fe61bbd..6ab4e6332 100644
--- a/docs/user-guide/07-autodiff.md
+++ b/docs/user-guide/07-autodiff.md
@@ -93,6 +93,48 @@ __bwd_diff(myFunc)(a, x, 1.0);
This completes the walkthrough of automatic differentiation features. The following sections will cover each aspect of the automatic differentiation feature in more detail.
+## Mathematic Concepts and Terminologies
+
+This section briefly reviews the mathematical theory behind differentiable programming, with the intention of clarifying the concepts and terminology used in the rest of this documentation. We assume the reader is already familiar with the basic theory behind neural network training, in particular the backpropagation algorithm.
+
+A differentiable system can be represented as a composition of differentiable functions (kernels) with learnable parameters, where each differentiable function has the form:
+
+$$\mathbf{w}_{i+1} = f_i(\mathbf{w}_i) $$
+
+Where $$f_i$$ represents a differentiable function (kernel) in the system, $$\mathbf{w}_i$$ represents the collection of learnable parameters defined in function $$f_i$$, and $$\mathbf{w}_{i+1}$$ is the output of $$f_i$$. We will use $$\omega_i$$ to denote a specific parameter in $$\mathbf{w}_i$$.
+
+In a composed system, the value of $$\mathbf{w}_i$$ used to evaluate $$f_i$$ may come from an *upstream* function:
+
+$$ \mathbf{w}_i = f_{i-1}(\mathbf{w}_{i-1}) $$
+
+Similarly, the value computed by $$f_i$$ may be used as an argument to a *downstream* function:
+
+$$ h = f_{i+1}(\mathbf{w}_{i+1}) = f_{i+1}(f_{i}(\mathbf{w}_{i}))$$
+
+The entire system composed from differentiable functions can be written as:
+
+$$Y = f_n \circ f_{n-1} \circ \cdots \circ f_0(\mathbf{w}_0)$$
+
+Where $$\mathbf{w}_0$$ is the first layer of parameters.
+
+### Forward Propagation of Derivatives
+When developing and training such a system, we are typically interested in evaluating the partial derivative of the system output with regard to some parameter $$\omega$$. To do so, we can use the forward and backward derivative propagation functions of each $$f_i$$. The forward derivative propagation function is defined by:
+
+$$ \cal{F}[f_i] = f_i'(\mathbf{w}_i, \mathbf{w}_i') = \sum_{\omega_i\in\mathbf{w}_i} \frac{\partial f_i}{\partial \omega_i} \omega_i' $$
+
+Where $$\omega_i' \in \mathbf{w}_i'$$ represents the partial derivative of $$\omega_i$$ with regard to some upstream parameter $$\omega_{i-1}$$ that is used to compute $$\omega_i$$, i.e. $$\omega_i'=\frac{\partial \omega_{i}}{\partial \omega_{i-1}}$$.
+
+Given this definition, $$\cal{F}[f_i]$$ can be used as a forward propagation function that is able to compute $$\frac{\partial f_i}{\partial \omega_0}$$ from $$\frac{\partial \omega_{i}}{\partial \omega_0}$$.
+
+### Backward Propagation of Derivatives
+When training a neural network, we are more interested in the partial derivative of the final system output with regard to a parameter $$\omega_i$$ in $$f_i$$. To compute it, we generally use the backward derivative propagation function:
+
+$$\cal{B}[f_i] = f_i^{-1}(\frac{\partial Y}{\partial f_i}) = \frac{\partial Y}{\partial \mathbf{w}_i}$$
+
+Where the backward propagation function $$\cal{B}[f_i]$$ takes as input the partial derivative of the final system output $$Y$$ with regard to the output of $$f_i$$ (i.e. $$\mathbf{w}_{i+1}$$), and computes the partial derivative of the final system output with regard to the input of $$f_i$$ (i.e. $$\mathbf{w}_i$$).
+
+The higher-order operators $$\cal{F}$$ and $$\cal{B}$$ represent the operations that convert an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provides builtin support for these operators to automatically generate the derivative propagation functions from a user-defined primal function. The remaining documentation will discuss this feature from a programming language perspective.
+
## Differentiable Types
Slang will only generate differentiation code for values that have a *differentiable* type. A type is differentiable if it conforms to the builtin `IDifferentiable` interface. The definition of the `IDifferentiable` interface is:
```csharp
@@ -254,14 +296,14 @@ struct MyRayDifferential : IDifferentiable
In this specific case, the automatically generated `IDifferentiable` implementation will be exactly the same as the manually written code listed above.
-## Forward Derivative Function
+## Forward Derivative Propagation Function
-Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative of the function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward propagation function. This and the next sections cover the semantics of forward derivative and backward propagation functions, and different ways to make a function forward and backward differentiable.
+Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propagation function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This section and the next cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable.
-A forward derivative function computes the derivative of the result value with regard to a specific set of input parameters.
-Given an original function, the signature of its forward derivative function is determined using the following rules:
-- If the return type `R` is differentiable, the forward derivative function will return `DifferentialPair<R>` that consists of both the computed original result value as well as the derivative of the result value. Otherwise, the return type is kept unmodified as `R`.
-- If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair<T>` parameter in the derivative function to allow passing in the initial derivatives of each parameter in addition to the original values.
+A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters.
+Given an original function, the signature of its forward propagation function is determined using the following rules:
+- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair<R>` that consists of both the computed original result value as well as the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`.
+- If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair<T>` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters.
- All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function.
For example, given original function:
@@ -269,10 +311,12 @@ For example, given original function:
R original(T0 p0, inout T1 p1, T2 p2);
```
Where `R`, `T0`, and `T1` are differentiable and `T2` is non-differentiable, the forward derivative function will have the following signature:
-```Swift
+```csharp
DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2);
```
+This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged.
+
`DifferentialPair<T>` is a builtin type that carries both the original and derivative value of a term. It is defined as follows:
```csharp
struct DifferentialPair<T : IDifferentiable> : IDifferentiable
@@ -288,20 +332,20 @@ struct DifferentialPair<T : IDifferentiable> : IDifferentiable
### Automatic Implementation of Forward Derivative Functions
-A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward-derivative function. The syntax for using `[ForwardDifferentiable]` is:
+A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward propagation function. The syntax for using `[ForwardDifferentiable]` is:
```csharp
[ForwardDifferentiable]
R original(T0 p0, inout T1 p1, T2 p2);
```
-Once the function is made forward-differentiable, the forward derivative function can then be called with the `__fwd_diff` operator:
+Once the function is made forward-differentiable, the forward propagation function can then be called with the `__fwd_diff` operator:
```csharp
DifferentialPair<R> result = __fwd_diff(original)(...);
```
### User Defined Forward Derivative Functions
-As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide an derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward-derivative implementation. The syntax for using `[ForwardDerivative]` attribute is:
+As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide a derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward derivative propagation implementation. The syntax for using the `[ForwardDerivative]` attribute is:
```csharp
DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2)
{
@@ -339,21 +383,22 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T
}
```
-## Backward Propagation Function
+## Backward Derivative Propagation Function
-A backward propagation function propagates the derivative of the function output to all the input parameters simultaneously.
-Given an orignal function, the signature of its backward propagation function is determined using the following rules:
-- A back-prop function always returns `void`.
-- A differentiable `in` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The propagated derivative will be written to the derivative part of the differential pair after the back-prop function returns. The initial derivative value of the pair is ignored as input.
-- A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the result derivative of the return value to propagate back to other input parameters.
-- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the resulting derivative of the argument is passed as input to the back-prop function as the original and derivative part of the pair. The propagated derivative to this input parameter will be written back and replace the derivative part of the pair. The resulting primal value will *not* be written back to the parameter.
-- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the back-prop function parameter list, carrying the result derivative of the return value to propagate back to the input parameters.
-- A non-differential return value of type `NDR` will be dropped.
-- A non-differential `in` parameter of type `ND` will remain unchanged in the back-prop function.
-- A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the back-prop function.
-- A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter.
+A backward derivative propagation function propagates the derivative of the function output to all the input parameters simultaneously.
+
+Given an original function `f`, the general rule for determining the signature of its backward propagation function is that a differentiable output `o` becomes an input parameter holding the partial derivative of a downstream output with regard to this differentiable output (i.e. $$\partial y/\partial o$$); an input differentiable parameter `i` in the original function becomes an output in the backward propagation function, holding the propagated partial derivative $$\partial y/\partial i$$; and any non-differentiable outputs are dropped from the backward propagation function. This means that the backward propagation function never returns any values computed in the original function.
-The general rule is that any differentiable output becomes an input derivative parameter, and any non-differentiable outputs are dropped from the back-prop function. This means that the back-prop function never returns any values computed in the original function.
+More specifically, the signature of its backward propagation function is determined using the following rules:
+- A backward propagation function always returns `void`.
+- A differentiable `in` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the backward propagation function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input.
+- A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to this output parameter.
+- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument, is passed as input to the backward propagation function as the original and derivative parts of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated.
+- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the partial derivative of a downstream term with regard to the return value of the original function.
+- A non-differentiable return value of type `NDR` will be dropped.
+- A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function.
+- A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function.
+- A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter.
For example, consider the following original function:
```csharp
@@ -397,7 +442,7 @@ The syntax for using `[BackwardDerivative]` attribute is:
```csharp
void back_prop(
inout DifferentialPair<T> p0,
- T.Differential p1,
+ T1.Differential p1,
inout DifferentialPair<T> p2,
ND p3,
ND p5,
@@ -418,7 +463,7 @@ R original(T0 p0, inout T1, p1, T2 p2);
[BackwardDerivativeOf(original)]
void back_prop(
inout DifferentialPair<T> p0,
- T.Differential p1,
+ T1.Differential p1,
inout DifferentialPair<T> p2,
ND p3,
ND p5,
@@ -443,7 +488,7 @@ The following builtin functions are backward differentiable and both their forwa
Sometimes we do not want a parameter to be considered differentiable even though it has a differentiable type. We can use the `no_diff` modifier on the parameter to inform the compiler to treat the parameter as non-differentiable and skip generating differentiation code for the parameter. The syntax is:
```csharp
-// Only differentaite this function with regard to `x`.
+// Only differentiate this function with regard to `x`.
float myFunc(no_diff float a, float x);
```
diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html
index b65d7f831..a956fa474 100644
--- a/docs/user-guide/toc.html
+++ b/docs/user-guide/toc.html
@@ -80,9 +80,10 @@
<li data-link="07-autodiff"><span>Automatic Differentiation</span>
<ul class="toc_list">
<li data-link="07-autodiff#using-automatic-differentiation-in-slang"><span>Using Automatic Differentiation in Slang</span></li>
+<li data-link="07-autodiff#mathematic-concepts-and-terminologies"><span>Mathematic Concepts and Terminologies</span></li>
<li data-link="07-autodiff#differentiable-types"><span>Differentiable Types</span></li>
-<li data-link="07-autodiff#forward-derivative-function"><span>Forward Derivative Function</span></li>
-<li data-link="07-autodiff#backward-propagation-function"><span>Backward Propagation Function</span></li>
+<li data-link="07-autodiff#forward-derivative-propagation-function"><span>Forward Derivative Propagation Function</span></li>
+<li data-link="07-autodiff#backward-derivative-propagation-function"><span>Backward Derivative Propagation Function</span></li>
<li data-link="07-autodiff#builtin-differentiable-functions"><span>Builtin Differentiable Functions</span></li>
<li data-link="07-autodiff#excluding-parameters-from-differentiation"><span>Excluding Parameters From Differentiation</span></li>
<li data-link="07-autodiff#calling-non-differentiable-functions-from-a-differentiable-function"><span>Calling Non-Differentiable Functions from a Differentiable Function</span></li>
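
The forward and backward propagation signatures described in the `07-autodiff.md` changes can be made concrete with a small sketch. The following is not Slang and not compiler output: it is a minimal Python emulation, assuming a hypothetical primal function `f(a, x) = a * x^2` and a hand-written stand-in for `DifferentialPair<T>` with the `p` (primal) and `d` (derivative) fields mentioned in the patch. The names `primal`, `fwd_diff_primal`, and `bwd_diff_primal` are illustrative only.

```python
from dataclasses import dataclass


@dataclass
class DifferentialPair:
    """Python stand-in for DifferentialPair<T>: primal value p, derivative d."""
    p: float
    d: float = 0.0


def primal(a: float, x: float) -> float:
    # Hypothetical primal function f(a, x) = a * x^2.
    return a * x * x


def fwd_diff_primal(a: DifferentialPair, x: DifferentialPair) -> DifferentialPair:
    # Forward propagation: each differentiable parameter becomes a pair, and
    # the result carries f together with sum_i (df/dw_i) * w_i'.
    value = primal(a.p, x.p)
    derivative = (x.p * x.p) * a.d + (2.0 * a.p * x.p) * x.d
    return DifferentialPair(value, derivative)


def bwd_diff_primal(a: DifferentialPair, x: DifferentialPair, d_result: float) -> None:
    # Backward propagation: returns nothing, takes dY/df as the trailing
    # input, and writes dY/da and dY/dx into the derivative parts of the
    # pairs. The primal parts are left untouched.
    a.d = d_result * (x.p * x.p)
    x.d = d_result * (2.0 * a.p * x.p)


# Forward mode: seed x' = 1 and a' = 0 to obtain df/dx alongside f.
fwd = fwd_diff_primal(DifferentialPair(3.0, 0.0), DifferentialPair(2.0, 1.0))

# Backward mode: seed dY/df = 1 to obtain both partial derivatives in one call.
a_pair, x_pair = DifferentialPair(3.0), DifferentialPair(2.0)
bwd_diff_primal(a_pair, x_pair, 1.0)
```

In this sketch one forward call yields the derivative with regard to a single seeded input, while one backward call fills in the derivatives for all inputs at once, which mirrors the distinction the patch draws between the two propagation functions.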