summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-07-12 13:02:57 -0700
committerGitHub <noreply@github.com>2023-07-12 13:02:57 -0700
commit39b7df94b287b2115f41ca038d560102246d0696 (patch)
treed0066f0a9bc88ecdcc04f39a167b6c04c4b1d33a
parentd0901aa7933ac31b0bf7648a31ec5c13de864457 (diff)
Update autodiff documentation. (#2979)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--docs/user-guide/07-autodiff.md56
-rw-r--r--docs/user-guide/a1-02-slangpy.md30
2 files changed, 43 insertions, 43 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md
index dc19ba0da..b3be3fbca 100644
--- a/docs/user-guide/07-autodiff.md
+++ b/docs/user-guide/07-autodiff.md
@@ -30,9 +30,9 @@ float myFunc(float a, float x)
}
```
-This allows the function to be used in the `__fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function.
+This allows the function to be used in the `fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function.
-The expression `__fwd_diff(myFunc)` will have the following signature:
+The expression `fwd_diff(myFunc)` will have the following signature:
```csharp
DifferentialPair<float> myFunc_fwd_derivative(DifferentialPair<float> a, DifferentialPair<float> x);
```
@@ -44,7 +44,7 @@ To use this function to compute the derivative of `myFunc` with regard to `x`, t
float a = 2.0;
float x = 3.0;
// Compute derivative with regard to `x`:
-let result = __fwd_diff(myFunc)(diffPair(a, 0.0), diffPair(x, 1.0));
+let result = fwd_diff(myFunc)(diffPair(a, 0.0), diffPair(x, 1.0));
// Print the derivative.
printf("%f", result.d);
@@ -57,9 +57,9 @@ In the example code above, `diffPair()` is a builtin function to construct a val
The forward derivative function allows the user to compute the derivative of a function with regard to a specific combination of input parameters at a time. In many cases, we need to know how each parameter affects the output. Instead of calling the forward derivative function once for each parameter, it is more efficient to call the *backward propagation* function that propagate the derivative of outputs to each input parameter.
-To allow the compiler to generate the backward propagation function, we simply mark our function with the `[BackwardDifferentiable]` attribute:
+To allow the compiler to generate the backward propagation function, we simply mark our function with the `[Differentiable]` or `[BackwardDifferentiable]` attribute:
```csharp
-[BackwardDifferentiable]
+[Differentiable]
float myFunc(float a, float x)
{
return a * x * x;
@@ -67,10 +67,10 @@ float myFunc(float a, float x)
```
> #### Note:
-> When a function is marked as `[BackwardDifferentiable]`, it is implied that the function is also `[ForwardDifferentiable]` and can be used in the `__fwd_diff` operator.
+> When a function is marked as `[Differentiable]`, it is implied that the function is both `[ForwardDifferentiable]` and `[BackwardDifferentiable]` and can be used in the `fwd_diff` operator.
-The `__bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `__bwd_diff(myFunc)` will have the following signature:
+The `bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `bwd_diff(myFunc)` will have the following signature:
```csharp
void myFunc_backProp(inout DifferentiablePair<float> a, inout DifferentiablePair<float> x, float dResult);
@@ -85,7 +85,7 @@ The backward propagation function can be called as in the following code:
var a = diffPair(2.0); // constructs DifferentialPair{2.0, 0.0}
var x = diffPair(3.0); // constructs DifferentialPair{3.0, 0.0}
-__bwd_diff(myFunc)(a, x, 1.0);
+bwd_diff(myFunc)(a, x, 1.0);
// a.d is now 9.0
// x.d is now 12.0
@@ -298,7 +298,7 @@ In this specific case, the automatically generated `IDifferentiable` implementat
## Forward Derivative Propagation Function
-Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `__fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propgation function. Likewise, the `__bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable.
+Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propgation function. Likewise, the `bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable.
A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters.
Given an original function, the signature of its forward propagation function is determined using the following rules:
@@ -339,9 +339,9 @@ A function can be made forward-differentiable with a `[ForwardDifferentiable]` a
R original(T0 p0, inout T1, p1, T2 p2);
```
-Once the function is made forward-differentiable, the forward propagation function can then be called with the `__fwd_diff` operator:
+Once the function is made forward-differentiable, the forward propagation function can then be called with the `fwd_diff` operator:
```csharp
-DifferentialPair<R> result = __fwd_diff(original)(...);
+DifferentialPair<R> result = fwd_diff(original)(...);
```
### User Defined Forward Derivative Functions
@@ -406,7 +406,7 @@ struct T : IDifferentiable {...}
struct R : IDifferentiable {...}
struct ND {} // Non differentiable
-[BackwardDifferentiable]
+[Differentiable]
R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5);
```
The signature of its backward propagation function is:
@@ -423,16 +423,16 @@ Note that although `p2` is still `inout` in the backward propagation function, t
### Automatically Implemented Backward Propagation Functions
-A function can be made backward-differentiable with a `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[BackwardDifferentiable]` is:
+A function can be made backward-differentiable with a `[Differentiable]` or `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[Differentiable]` is:
```csharp
-[BackwardDifferentiable]
+[Differentiable]
R original(T0 p0, inout T1, p1, T2 p2);
```
-Once the function is made backward-differentiable, the backward propagation function can then be called with the `__bwd_diff` operator:
+Once the function is made backward-differentiable, the backward propagation function can then be called with the `bwd_diff` operator:
```csharp
-__bwd_diff(original)(...);
+bwd_diff(original)(...);
```
### User Defined Backward Propagation Functions
@@ -498,7 +498,7 @@ float myTerm(float x)
float getSample(float a, float b) { ... }
-[BackwardDifferentiable]
+[Differentiable]
float computeIntegralOverMyTerm(float x, float a, float b)
{
float sum = 0.0;
@@ -530,7 +530,7 @@ Primal subsitute can be used as another way to make a function differentiable. A
float myFunc(float x) {...}
[PrimalSubstituteOf(myFunc)]
-[BackwardDifferentiable]
+[Differentiable]
float myFuncSubst(float x) {...}
// myFunc is now considered backward differentiable.
@@ -548,8 +548,8 @@ float myFuncSubst(float x) { return x*x*x; }
float caller(float x) { return myFunc(x); }
let a = caller(4.0); // a == 16.0 (calling myFunc)
-let b = __fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst)
-let c = __fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst)
+let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst)
+let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst)
```
In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal subsitute function will be used instead of the derivatives defined on the original function.
@@ -639,7 +639,7 @@ float f(float x)
return t.member;
}
...
-let result = __fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0
+let result = fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0
```
In this case, we are assigning the value `x*x`, which carries a derivative, into a non-differentiable location `MyType.member`, thus throwing away any derivative info. When `f` returns `t.member`, there will be no derivative associated with it so
the function will not propagate the derivative through. This code is most likely not intending to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang will treat any assignments of a value with
@@ -674,7 +674,7 @@ float f(float x)
return g(x) + x * x;
}
```
-The derivative will not propagate through the call to `g` in `f`. As a result, `__fwd_diff(f)(diffPair(1.0, 1.0))` will return
+The derivative will not propagate through the call to `g` in `f`. As a result, `fwd_diff(f)(diffPair(1.0, 1.0))` will return
`{3.0, 2.0}` instead of `{3.0, 4.0}` as the derivative from `2*x` is lost through the non-differentiable call. To prevent unintended error, it is treated as a compile-time error to call `g` from `f`. If such a non-differentiable call is intended, a `no_diff` prefix is required in the call:
```csharp
[ForwardDifferentiable]
@@ -696,7 +696,7 @@ See the following code for an example of `[TreatAsDifferentiable]`:
```csharp
interface IFoo
{
- [BackwardDifferentiable]
+ [Differentiable]
float f(float v);
}
@@ -709,7 +709,7 @@ struct B : IFoo
}
}
-[BackwardDifferentiable]
+[Differentiable]
float use(IFoo o, float x)
{
return o.f(x);
@@ -717,20 +717,20 @@ float use(IFoo o, float x)
// Test:
B obj;
-float result = __fwd_diff(use)(obj, diffPair(2.0, 1.0)).d;
+float result = fwd_diff(use)(obj, diffPair(2.0, 1.0)).d;
// result == 0.0, since `[TreatAsDifferentiable]` causes a trivial derivative implementation
// being generated regardless of the original code.
```
## Higher Order Differentiation
-Slang supports generating higher order forward and backward derivative propagation functions. It is allowed to use `__fwd_diff` and `__bwd_diff` operators inside a forward or backward differentiable function, or to nest `__fwd_diff` and `__bwd_diff` operators. For example, `__fwd_diff(__fwd_diff(sin))` will have the following signature:
+Slang supports generating higher order forward and backward derivative propagation functions. It is allowed to use `fwd_diff` and `bwd_diff` operators inside a forward or backward differentiable function, or to nest `fwd_diff` and `bwd_diff` operators. For example, `fwd_diff(fwd_diff(sin))` will have the following signature:
```csharp
DifferentialPair<DifferentialPair<float>> sin_diff2(DifferentialPair<DifferentialPair<float>> x);
```
-The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `__fwd_diff(__fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`.
+The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `fwd_diff(fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`.
User defined higher-order derivative functions can be specified by using `[ForwardDerivative]` or `[BackwardDerivative]` attribute on the derivative function, or by using `[ForwardDerivativeOf]` or `[BackwardDerivativeOf]` attribute on the higher-order derivative function.
@@ -738,7 +738,7 @@ User defined higher-order derivative functions can be specified by using `[Forwa
Automatic differentiation for generic functions is supported. The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a derivative function with different set of generic parameters or constraints is a compile-time error.
-An interface method requirement can be marked as `[ForwardDifferentiable]` or `[BackwardDifferentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported.
+An interface method requirement can be marked as `[ForwardDifferentiable]` or `[Differentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported.
## Restrictions of Automatic Differentiation
diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md
index 880db7ca6..702bc0cf9 100644
--- a/docs/user-guide/a1-02-slangpy.md
+++ b/docs/user-guide/a1-02-slangpy.md
@@ -118,13 +118,13 @@ for `square`, and expose it to PyTorch as an autograd function.
First we need to tell Slang compiler that we need the `square` function to be considered a differentiable function so Slang compiler can generate a backward derivative propagation function for it:
```csharp
-[BackwardDifferentiable]
+[Differentiable]
float square(float x)
{
return x * x;
}
```
-This is done by simply adding a `[BackwardDifferentiable]` attribute to our `square`function.
+This is done by simply adding a `[Differentiable]` attribute to our `square`function.
With that, we can now define `square_bwd_kernel` that performs backward propagation as:
@@ -139,17 +139,17 @@ void square_bwd_kernel(TensorView<float> input, TensorView<float> grad_out, Tens
DifferentialPair<float> dpInput = diffPair(input[globalIdx.xy]);
var gradInElem = grad_out[globalIdx.xy];
- __bwd_diff(square)(dpInput, gradInElem);
+ bwd_diff(square)(dpInput, gradInElem);
grad_propagated[globalIdx.xy] = dpInput.d;
}
```
Note that the function follows the same structure of `square_fwd_kernel`, with the only difference being that
-instead of calling into `square` to compute the forward value for each tensor element, we are calling `__bwd_diff(square)`
+instead of calling into `square` to compute the forward value for each tensor element, we are calling `bwd_diff(square)`
that represents the automatically generated backward propagation function of `square`.
-`__bwd_diff(square)` will have the following signature:
+`bwd_diff(square)` will have the following signature:
```csharp
-void __bwd_diff_square(inout DifferentialPair<float> dpInput, float dOut);
+void bwd_diff_square(inout DifferentialPair<float> dpInput, float dOut);
```
Where the first parameter, `dpInput` represents a pair of original and derivative value for `input`, and the second parameter,
@@ -159,7 +159,7 @@ derivative will be stored in `dpInput.d`. For example:
```csharp
// construct a pair where the primal value is 3, and derivative value is 0.
var dp = diffPair(3.0);
-__bwd_diff(square)(dp, 1.0);
+bwd_diff(square)(dp, 1.0);
// dp.d is now 6.0
```
@@ -201,7 +201,7 @@ class MySquareFuncInSlang(torch.autograd.Function):
Now we can use the autograd function `MySquareFuncInSlang` in our python script:
```python
-x = torch.tensor([[3.0, 4.0],[0.0, 1.0]], requires_grad=True, device=cuda_device)
+x = torch.tensor([[3.0, 4.0],[0.0, 1.0]], requires_grad=True, device='cuda')
print(f"X = {x}")
y_pred = MySquareFuncInSlang.apply(x)
loss = y_pred.sum()
@@ -280,8 +280,8 @@ void boxFilter_fwd(TensorView<float> input, TensorView<float> output)
```
How do we define the backward derivative propagation kernel? Note that in this example, there
-isn't a function like `square` that we can just mark as `[BackwardDifferentiable]` and
-call `__bwd_diff(square)` to get back the derivative of an input parameter.
+isn't a function like `square` that we can just mark as `[Differentiable]` and
+call `bwd_diff(square)` to get back the derivative of an input parameter.
In this example, the input comes from multiple elements in a tensor. How do we propagate the
derivatives to those input elements?
@@ -305,7 +305,7 @@ Now we can replace all direct accesses to `input` with a call to `getInputElemen
`computeOutputPixel` can be implemented as following:
```csharp
-[BackwardDifferentiable]
+[Differentiable]
float computeOutputPixel(
TensorView<float> input,
TensorView<float> inputGradToPropagateTo,
@@ -345,7 +345,7 @@ float computeOutputPixel(
The main changes compared to our original version of `computeOutputPixel` are:
- Added a `inputGradToPropagateTo` parameter.
- Modified `input[x,y]` with a call to `getInputElement`.
-- Added a `[BackwardDifferentiable]` attribute to the function.
+- Added a `[Differentiable]` attribute to the function.
With that, we can define our backward kernel function:
@@ -362,11 +362,11 @@ void boxFilter_bwd(
if (pixelLoc.x >= width) return;
if (pixelLoc.y >= height) return;
- __bwd_diff(computeOutputPixel)(input, inputGradToPropagateTo, pixelLoc);
+ bwd_diff(computeOutputPixel)(input, inputGradToPropagateTo, pixelLoc);
}
```
-The kernel function simply calls `__bwd_diff(computeOutputPixel)` without taking any return values from the call
+The kernel function simply calls `bwd_diff(computeOutputPixel)` without taking any return values from the call
and without writing to any elements in the final `inputGradToPropagateTo` tensor. But when exactly does the proapgated
output get written to the output gradient tensor (`inputGradToPropagateTo`)?
@@ -387,7 +387,7 @@ void getInputElement_bwd(
Here, we are providing a custom defined backward propagation function for `getInputElement`.
In this function, we simply add `derivative` to the element in `inputGradToPropagateTo` tensor.
-When we call `__bwd_diff(computeOutputPixel)` in `boxFilter_bwd`, the Slang compiler will automatically
+When we call `bwd_diff(computeOutputPixel)` in `boxFilter_bwd`, the Slang compiler will automatically
differentiate all operations and function calls in `computeOutputPixel`. By wrapping the tensor element access
with `getInputElement` and by providing a custom backward propagation function of `getInputElement`, we are effectively
telling the compiler what to do when a derivative propagates to an input tensor element. Inside the body