| -rw-r--r-- | docs/user-guide/07-autodiff.md | 73 |
| -rw-r--r-- | docs/user-guide/a1-02-slangpy.md | 28 |
2 files changed, 49 insertions, 52 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index b3be3fbca..1f6cb06af 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -4,7 +4,7 @@ layout: user-guide # Automatic Differentiation -Neural networks and other machine learning techniques are becoming an increasingly popular way to solve many difficult problems in modern visual computing systems. However, to take advantage of these techniques, developers often need to reimplement many existing system components in a differentiable form to allow computing the derivatives of a function, or to propagate the derivative of a result backwards to each parameter. Slang provides builtin auto differentiation features to support developers adding differentiability to their existing code with as little effort as possible. In this chapter, we provide an overview of the auto differentiation features, followed by a detailed description on the new syntax and rules. +Neural networks and other machine learning techniques are becoming an increasingly popular way to solve many difficult problems in modern visual computing systems. However, to take advantage of these techniques, developers often need to reimplement many existing system components in a differentiable form to allow computing the derivatives of a function, or to propagate the derivative of a result backwards to each parameter. Slang provides built-in auto differentiation features to support developers adding differentiability to their existing code with as little effort as possible. In this chapter, we provide an overview of the auto differentiation features, followed by a detailed description on the new syntax and rules. 
## Using Automatic Differentiation in Slang @@ -30,14 +30,14 @@ float myFunc(float a, float x) } ``` -This allows the function to be used in the `fwd_diff` operator, which is an higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function. +This allows the function to be used in the `fwd_diff` operator, which is a higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function. The expression `fwd_diff(myFunc)` will have the following signature: ```csharp DifferentialPair<float> myFunc_fwd_derivative(DifferentialPair<float> a, DifferentialPair<float> x); ``` -Where `DifferentialPair<T>` is a builtin type that encodes both the primal(original) value and the derivative value of a term. +Where `DifferentialPair<T>` is a built-in type that encodes both the primal(original) value and the derivative value of a term. To use this function to compute the derivative of `myFunc` with regard to `x`, the user can call the forward-derivative function by supplying the derivative value of `x` with `1.0` and the derivative value of `a` with `0.0`, as in the following code: ```csharp @@ -51,7 +51,7 @@ printf("%f", result.d); // Output: 12.0 ``` -In the example code above, `diffPair()` is a builtin function to construct a value of `DifferentialPair<T>` with a primal value and a derivative value. The primal value and derivative value stored in a `DifferentialPair` can be accessed with the `.p` and a `.d` property. +In the example code above, `diffPair()` is a built-in function to construct a value of `DifferentialPair<T>` with a primal value and a derivative value. The primal value and derivative value stored in a `DifferentialPair` can be accessed with the `.p` and a `.d` property. 
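Editorial sketch (not part of the patch): the forward-mode behavior described in the hunk above can be modeled outside Slang with dual-number arithmetic. This is plain Python, not Slang, and not the compiler's implementation. `DiffPair`, `my_func`, and `my_func_fwd` are hypothetical names. The body `a * x * x` and the inputs `a = 3.0`, `x = 2.0` are assumptions, since the real `myFunc` body is not visible in this diff; they were chosen only so that the derivative agrees with the `12.0` shown in the hunk context.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DiffPair:
    """Hypothetical stand-in for DifferentialPair<float>: primal p, derivative d."""
    p: float
    d: float = 0.0


def my_func(a: float, x: float) -> float:
    # Assumed body; the real myFunc is not shown in this diff.
    return a * x * x


def my_func_fwd(a: DiffPair, x: DiffPair) -> DiffPair:
    """Hand-written forward derivative of my_func, mirroring the shape of fwd_diff(myFunc)."""
    primal = my_func(a.p, x.p)
    # Product rule: d(a*x*x) = x*x*da + 2*a*x*dx
    deriv = x.p * x.p * a.d + 2.0 * a.p * x.p * x.d
    return DiffPair(primal, deriv)


# Seed dx = 1.0 and da = 0.0 to obtain the partial derivative with respect to x.
result = my_func_fwd(DiffPair(3.0, 0.0), DiffPair(2.0, 1.0))
print(result.p, result.d)
```

Swapping the seeds (`da = 1.0`, `dx = 0.0`) yields the partial derivative with respect to `a` instead, which is why one forward pass is needed per input of interest.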
### Backward Propagation @@ -73,12 +73,12 @@ float myFunc(float a, float x) The `bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `bwd_diff(myFunc)` will have the following signature: ```csharp -void myFunc_backProp(inout DifferentiablePair<float> a, inout DifferentiablePair<float> x, float dResult); +void myFunc_backProp(inout DifferentialPair<float> a, inout DifferentialPair<float> x, float dResult); ``` -Where `a` is an `inout DifferentiablePair` where the initial value of `a` is passed into the function as primal value (in the `.p` property), and the propagated derivative of `a` is returned via the `.d` property of the `DifferentialPair`. The same rules applies to `x`. +Where `a` is an `inout DifferentialPair` where the initial value of `a` is passed into the function as primal value (in the `.p` property), and the propagated derivative of `a` is returned via the `.d` property of the `DifferentialPair`. The same rules apply to `x`. -The additional `dResult` parameter is the derivative of the return value to be propagated to the input parameters. Note that in a backward propagation function, an input will become a `inout DifferentiablePair` where the `.d` property of the pair is intended for receiving the propagation result, and the return value will become an input parameter that represents the source of backward propagation. +The additional `dResult` parameter is the derivative of the return value to be propagated to the input parameters. Note that in a backward propagation function, an input will become a `inout DifferentialPair` where the `.d` property of the pair is intended for receiving the propagation result, and the return value will become an input parameter that represents the source of backward propagation. 
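Editorial sketch (not part of the patch): the `inout DifferentialPair` convention described above can be modeled in plain Python with mutable pairs. This is an illustration under stated assumptions, not Slang and not what the compiler emits. `DiffPair` and `my_func_back_prop` are hypothetical names, and the body `a * x * x` is assumed because the real `myFunc` body is not visible in this diff.

```python
from dataclasses import dataclass


@dataclass
class DiffPair:
    """Hypothetical mutable stand-in for an `inout DifferentialPair<float>`."""
    p: float
    d: float = 0.0


def my_func_back_prop(a: DiffPair, x: DiffPair, d_result: float) -> None:
    """Hand-written backward propagation for the assumed body a * x * x."""
    # Primal inputs are read from .p; propagated derivatives are written to .d.
    a.d = x.p * x.p * d_result        # d(result)/da = x*x
    x.d = 2.0 * a.p * x.p * d_result  # d(result)/dx = 2*a*x


a = DiffPair(3.0)
x = DiffPair(2.0)
# d_result = 1.0 propagates the derivative of the result with respect to itself.
my_func_back_prop(a, x, 1.0)
print(a.d, x.d)
```

Unlike the forward sketch, a single call fills in the derivatives for every input at once, and the function returns nothing, matching the `void` signature quoted above.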
The backward propagation function can be called as in the following code: ```csharp @@ -95,7 +95,7 @@ This completes the walkthrough of automatic differentiation features. The follow ## Mathematic Concepts and Terminologies -This secions briefly reviews the mathematic theories behind differentiable programming with the intention to clarify the concepts and terminologies that will be used in the rest of this documentation. We assume the reader is already familiar with the basic theories behind neural network training, in particular the backpropagation algorithm. +This section briefly reviews the mathematic theories behind differentiable programming with the intention to clarify the concepts and terminologies that will be used in the rest of this documentation. We assume the reader is already familiar with the basic theories behind neural network training, in particular the back-propagation algorithm. A differentiable system can be represented a composition of differentiable functions (kernels) with learnable parameters, where each differentiable function has the form: @@ -118,7 +118,7 @@ $$Y = f_1 \circ f_2 \circ \cdots \circ f_n(\mathbf{w}_0)$$ Where $$\mathbf{w}_0$$ is the first layer of parameters. ### Forward Propagation of Derivatives -When developing and training such a system, we often need to evaluate the partial derivative of a differentiable function with regard to some parameter $$\omega$$. The simpliest way to obtain a partial derivative is to call a forward derivative propagation function, which is defined by: +When developing and training such a system, we often need to evaluate the partial derivative of a differentiable function with regard to some parameter $$\omega$$. 
The simplest way to obtain a partial derivative is to call a forward derivative propagation function, which is defined by: $$ \mathbb{F}[f_i] = f_i'(\mathbf{w}_i, \mathbf{w}_i') = \sum_{\omega_i\in\mathbf{w}_i} \frac{\partial f}{\partial \omega_i} \omega_i' $$ @@ -133,10 +133,10 @@ $$\mathbb{B}[f_i] = f_i^{-1}(\frac{\partial Y}{\partial f_i}) = \frac{\partial Y Where the backward propagation function $$\mathbb{B}[f_i]$$ takes as input the partial derivative of the final system output $$Y$$ with regard to the output of $$f_i$$ (i.e. $$\mathbf{w}_i$$), and computes the partial derivative of the final system output with regard to the input of $$f_i$$ (i.e. $$\mathbf{w}_{i-1}$$). -The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide builtin support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective. +The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide built-in support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective. ## Differentiable Types -Slang will only generate differentiation code for values that has a *differentiable* type. A type is differentiable if it conforms to the builtin `IDifferentiable` interface. The definition of the `IDifferentiable` interface is: +Slang will only generate differentiation code for values that has a *differentiable* type. 
A type is differentiable if it conforms to the built-in `IDifferentiable` interface. The definition of the `IDifferentiable` interface is: ```csharp interface IDifferentiable { @@ -150,12 +150,12 @@ interface IDifferentiable static Differential dmul(This, Differential); } ``` -As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of a second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. +As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of the second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. In addition, a differentiable type must define the `zero` value of its derivative, and how to add and multiply derivative values. ### Builtin Differentiable Types -The following builtin types are differentiable: +The following built-in types are differentiable: - Scalars: `float`, `double` and `half`. - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. - Arrays: `T[n]` is differentiable if `T` is differentiable. @@ -256,7 +256,7 @@ struct MyRayDifferential : IDifferentiable float3 d_dir; } ``` -In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == flaot3` as defined in builtin library), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`. 
+In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == float3` as defined in built-in library), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`. We can also choose to manually implement `IDifferentiable` interface for `MyRayDifferential` as in the following code: @@ -298,11 +298,11 @@ In this specific case, the automatically generated `IDifferentiable` implementat ## Forward Derivative Propagation Function -Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propgation function. Likewise, the `bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable. +Functions in Slang can be marked as forward-differentiable or backward-differentiable. The `fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propagation function. Likewise, the `bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable. A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. 
Given an original function, the signature of its forward propagation function is determined using the following rules: -- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair<R>` that consists of both the computed original result value as well as the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. +- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair<R>` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. - If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair<T>` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. - All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. @@ -317,7 +317,7 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged. -`DifferentialPair<T>` is a builtin type that carrys both the original and derivative value of a term. It is defined as follows: +`DifferentialPair<T>` is a built-in type that carries both the original and derivative value of a term. 
It is defined as follows: ```csharp struct DifferentialPair<T : IDifferentiable> : IDifferentiable { @@ -345,7 +345,7 @@ DifferentialPair<R> result = fwd_diff(original)(...); ``` ### User Defined Forward Derivative Functions -As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide an derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward derivative propagation implementation. The syntax for using `[ForwardDerivative]` attribute is: +As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide a derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward derivative propagation implementation. The syntax for using `[ForwardDerivative]` attribute is: ```csharp DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2) { @@ -387,7 +387,7 @@ DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T A backward derivative propagation function propagates the derivative of the function output to all the input parameters simultaneously. -Given an orignal function `f`, the general rule for determining the signature of its backward propagation function is that a differentiable output `o` becomes an input parameter holding the partial derivative of a downstream output with regard to the this differentiable output, i.e. $$\partial y/\partial o$$); an input differentiable parameter `i` in the original function will become an output in the backward propagation function, holding the propagated partial derivative $$\partial y/\partial i$$; and any non-differentiable outputs are dropped from the backward propagation function. 
This means that the backward propagation function never returns any values computed in the original function. +Given an original function `f`, the general rule for determining the signature of its backward propagation function is that a differentiable output `o` becomes an input parameter holding the partial derivative of a downstream output with regard to the differentiable output, i.e. $$\partial y/\partial o$$); an input differentiable parameter `i` in the original function will become an output in the backward propagation function, holding the propagated partial derivative $$\partial y/\partial i$$; and any non-differentiable outputs are dropped from the backward propagation function. This means that the backward propagation function never returns any values computed in the original function. More specifically, the signature of its backward propagation function is determined using the following rules: - A backward propagation function always returns `void`. @@ -455,7 +455,7 @@ void back_prop( R original(T0 p0, inout T1, p1, T2 p2); ``` -Similarly the `[BackwardDerivativeOf]` attribute can be used on the back-prop function in case it is not convenient to modify the definition of the original function, or the back-prop function can't be made visible from the original function: +Similarly, the `[BackwardDerivativeOf]` attribute can be used on the back-prop function in case it is not convenient to modify the definition of the original function, or the back-prop function can't be made visible from the original function: ```csharp R original(T0 p0, inout T1, p1, T2 p2); @@ -475,9 +475,9 @@ void back_prop( ## Builtin Differentiable Functions -The following builtin functions are backward differentiable and both their forward-derivative and backward-propagation functions are already defined in the builtin library: +The following built-in functions are backward differentiable and both their forward-derivative and backward-propagation functions are already defined in the 
built-in library: -- Arithmetic functions: `abs`, `max`, `min`, `sqrt`, `rcp`, `rsqrt`, `fma`, `mad`, `fmod`, `frac`, `radians`, `degrees` +- Arithmetic functions: `abs`, `max`, `min`, `sqrt`, `rcp`, `rsqrt`, `fma`, `mad`, `fmod`, `frac`, `radians`, `degrees` - Interpolation and clamping functions: `lerp`, `smoothstep`, `clamp`, `saturate` - Trigonometric functions: `sin`, `cos`, `sincos`, `tan`, `asin`, `acos`, `atan`, `atan2` - Hyperbolic functions: `sinh`, `cosh`, `tanh` @@ -521,11 +521,11 @@ float getSampleForDerivativeComputation(float a, float b) } ``` -Here, the `[PrimalSubstituteOf(getSample)]` attributes marks the `getSampleForDerivativeComputation` function as the substitute for `getSample` in derivative propagation functions. When a function has a primal subsittute, the compiler will treat all calls to that function as if it is a call to the substiute function when generating derivative code. Note that this only applies to compiler generated derivative function and does not affect user provided derivative functions. If a user provided derivative function calls `getSample`, it will not be replaced by `getSampleForDerivativeComputation` by the compiler. +Here, the `[PrimalSubstituteOf(getSample)]` attributes marks the `getSampleForDerivativeComputation` function as the substitute for `getSample` in derivative propagation functions. When a function has a primal substitute, the compiler will treat all calls to that function as if it is a call to the substitute function when generating derivative code. Note that this only applies to compiler generated derivative function and does not affect user provided derivative functions. If a user provided derivative function calls `getSample`, it will not be replaced by `getSampleForDerivativeComputation` by the compiler. 
Similar to `[ForwardDerivative]` and `[ForwardDerivativeOf]` attributes, The `[PrimalSubsitute(substFunc)]` attribute works the other way around: it specifies the primal substitute function of the function being marked. -Primal subsitute can be used as another way to make a function differentiable. A function is considered differentiable if it has a primal subsitute that is differentiable. The following code illustrates this mechanism. +Primal substitute can be used as another way to make a function differentiable. A function is considered differentiable if it has a primal substitute that is differentiable. The following code illustrates this mechanism. ```csharp float myFunc(float x) {...} @@ -536,7 +536,7 @@ float myFuncSubst(float x) {...} // myFunc is now considered backward differentiable. ``` -The following example shows in more detail on how primal subsitute affects derivative computation. +The following example shows in more detail on how primal substitute affects derivative computation. ```csharp float myFunc(float x) { return x*x; } @@ -552,7 +552,7 @@ let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubs let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst) ``` -In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal subsitute function will be used instead of the derivatives defined on the original function. +In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. 
All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal substitute function will be used instead of the derivatives defined on the original function. ## Working with Mixed Differentiable and Non-Differentiable Code @@ -568,7 +568,7 @@ Sometimes we do not wish a parameter to be considered differentiable despite it float myFunc(no_diff float a, float x); ``` -The forward derivative and backward propgation functions of `myFunc` should have the following signature: +The forward derivative and backward propagation functions of `myFunc` should have the following signature: ```csharp DifferentialPair<float> fwd_derivative(float a, DifferentialPair<float> x); void back_prop(float a, inout DifferentialPair<float> x, float dResult); @@ -587,7 +587,7 @@ void back_prop(float a, inout DifferentialPair<float> x, float d_y); ### Excluding Struct Members from Differentiation -When using automatic `IDifferentiable` conformance synthesis for a `struct` type, Slang will by-default treat all struct members that have a differentiable type as differentiable, and thus include a correspondant field in the generated `Differential` type for the struct. +When using automatic `IDifferentiable` conformance synthesis for a `struct` type, Slang will by-default treat all struct members that have a differentiable type as differentiable, and thus include a corresponding field in the generated `Differential` type for the struct. For example, given the following definition ```csharp struct MyType : IDifferentiable @@ -596,7 +596,7 @@ struct MyType : IDifferentiable float2 member2; } ``` -Slang will genereate: +Slang will generate: ```csharp struct MyType.Differential : IDifferentiable { @@ -641,10 +641,8 @@ float f(float x) ... 
let result = fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0 ``` -In this case, we are assigning the value `x*x`, which carries a derivative, into a non-differentiable location `MyType.member`, thus throwing away any derivative info. When `f` returns `t.member`, there will be no derivative associated with it so -the function will not propagate the derivative through. This code is most likely not intending to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang will treat any assignments of a value with -derivative info into a non-differentiable location as a compile-time error. To eliminate this error, the user should either make `t.member` differentiable, or to force the assignment by clarifying the intention to discard any derivatives using the builtin `detach` method. -The following code will compile and the derivatives will be discarded: +In this case, we are assigning the value `x*x`, which carries a derivative, into a non-differentiable location `MyType.member`, thus throwing away any derivative info. When `f` returns `t.member`, there will be no derivative associated with it, so the function will not propagate the derivative through. This code is most likely not intending to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang will treat any assignments of a value with derivative info into a non-differentiable location as a compile-time error. To eliminate this error, the user should either make `t.member` differentiable, or to force the assignment by clarifying the intention to discard any derivatives using the built-in `detach` method. +The following code will compile, and the derivatives will be discarded: ```csharp [ForwardDifferentiable] float f(float x) @@ -657,8 +655,7 @@ float f(float x) ``` ### Calling Non-Differentiable Functions from a Differentiable Function -Calling non-differentiable function from a differentiable function is allowed. 
However, derivatives will not be propagated through the call. The user is required to clarify the intention by prefixing the call with the `no_diff` keyword. An unclarified -call to non-differnetiable function will result in a compile-time error. +Calling non-differentiable function from a differentiable function is allowed. However, derivatives will not be propagated through the call. The user is required to clarify the intention by prefixing the call with the `no_diff` keyword. An un-clarified call to non-differentiable function will result in a compile-time error. For example, consider the following code: ```csharp @@ -685,7 +682,7 @@ float f(float x) } ``` -However, the `no_diff` keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependant on the derivative being propagated through the call. +However, the `no_diff` keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependent on the derivative being propagated through the call. ### Treat Non-Differentiable Functions as Differentiable Slang allows functions to be marked with a `[TreatAsDifferentiable]` attribute for them to be considered as differentiable functions by the type-system. When a function is marked as `[TreatAsDifferentiable]`, the compiler will not generate derivative propagation code from the original function body or perform any additional checking on the function definition. Instead, it will generate trivial forward and backward propagation functions that returns 0. 
@@ -730,7 +727,7 @@ Slang supports generating higher order forward and backward derivative propagati DifferentialPair<DifferentialPair<float>> sin_diff2(DifferentialPair<DifferentialPair<float>> x); ``` -The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the orgiginal input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `fwd_diff(fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`. +The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d`, where `x.p.p` specifies the original input value, both `x.p.d` and `x.d.p` store the first order derivative if `x`, and `x.d.d` stores the second order derivative of `x`. Calling `fwd_diff(fwd_diff(sin))` with `diffPair(diffPair(pi/2, 1.0), DiffPair(1.0, 0.0))` will result `{ { 1.0, 0.0 }, { 0.0, -1.0 } }`. User defined higher-order derivative functions can be specified by using `[ForwardDerivative]` or `[BackwardDerivative]` attribute on the derivative function, or by using `[ForwardDerivativeOf]` or `[BackwardDerivativeOf]` attribute on the higher-order derivative function. @@ -738,7 +735,7 @@ User defined higher-order derivative functions can be specified by using `[Forwa Automatic differentiation for generic functions is supported. The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a derivative function with different set of generic parameters or constraints is a compile-time error. 
-An interface method requirement can be marked as `[ForwardDifferentiable]` or `[Differentiable]` so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether or not the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported. +An interface method requirement can be marked as `[ForwardDifferentiable]` or `[Differentiable]`, so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported. ## Restrictions of Automatic Differentiation diff --git a/docs/user-guide/a1-02-slangpy.md b/docs/user-guide/a1-02-slangpy.md index 702bc0cf9..d1d884739 100644 --- a/docs/user-guide/a1-02-slangpy.md +++ b/docs/user-guide/a1-02-slangpy.md @@ -12,7 +12,7 @@ In addition, using a per-thread programming model also results in more optimized ## Getting Started with slangpy -In this tutorial, we will use a simple example to walkthrough the steps to use Slang in your PyTorch project. +In this tutorial, we will use a simple example to walk through the steps to use Slang in your PyTorch project. ### Writing a simple kernel function as a Slang module @@ -27,7 +27,7 @@ float square(float x) } ``` -This function is self explanatory. To use it in PyTorch, we need to write a GPU kernel function (that maps to a +This function is self-explanatory. To use it in PyTorch, we need to write a GPU kernel function (that maps to a `__global__` CUDA function) that defines how to compute each element of the input tensor. 
So we continue to write the following Slang function: @@ -62,7 +62,7 @@ TorchTensor<float> square_fwd(TorchTensor<float> input) return result; } ``` -Here, we mark the function with the `[TorchEntryPoint]` attribute so it will be exported to Python. In the function body, we call `TorchTensor<float>.zerosLike` to allocate a 2D-tensor that has the same size as the input. +Here, we mark the function with the `[TorchEntryPoint]` attribute, so it will be exported to Python. In the function body, we call `TorchTensor<float>.zerosLike` to allocate a 2D-tensor that has the same size as the input. `zerosLike` returns a `TorchTensor<float>` object that represents a CPU handle of a PyTorch tensor. Then we launch `square_fwd_kernel` with the `__dispatch_kernel` syntax. Note that we can directly pass `TorchTensor<float>` arguments to a `TensorView<float>` parameter and the compiler will automatically convert @@ -113,10 +113,10 @@ The above example demonstrates how to write a simple kernel function in Slang an Another major benefit of using Slang is that the Slang compiler support generating backward derivative propagation functions automatically. -In the following section, we walkthrough how to use Slang to generate a backward propagation function +In the following section, we walk through how to use Slang to generate a backward propagation function for `square`, and expose it to PyTorch as an autograd function. 
-First we need to tell Slang compiler that we need the `square` function to be considered a differentiable function so Slang compiler can generate a backward derivative propagation function for it: +First we need to tell Slang compiler that we need the `square` function to be considered a differentiable function, so Slang compiler can generate a backward derivative propagation function for it: ```csharp [Differentiable] float square(float x) @@ -153,7 +153,7 @@ void bwd_diff_square(inout DifferentialPair<float> dpInput, float dOut); ``` Where the first parameter, `dpInput` represents a pair of original and derivative value for `input`, and the second parameter, -`dOut`, represents the initial derivative with regard to some latent variable that we wish to backprop through. The resulting +`dOut`, represents the initial derivative with regard to some latent variable that we wish to back-prop through. The resulting derivative will be stored in `dpInput.d`. For example: ```csharp @@ -229,8 +229,8 @@ surrounding 3x3 pixel block. We can write a Slang function that computes the val ```csharp float computeOutputPixel(TensorView<float> input, uint2 pixelLoc) { - int width = input.dim(0); - int height = input.dim(1); + int width = input.size(0); + int height = input.size(1); // Track the sum of neighboring pixels and the number // of pixels currently accumulated. @@ -367,7 +367,7 @@ void boxFilter_bwd( ``` The kernel function simply calls `bwd_diff(computeOutputPixel)` without taking any return values from the call -and without writing to any elements in the final `inputGradToPropagateTo` tensor. But when exactly does the proapgated +and without writing to any elements in the final `inputGradToPropagateTo` tensor. But when exactly does the propagated output get written to the output gradient tensor (`inputGradToPropagateTo`)? And that logic is defined in our final piece of code: @@ -392,7 +392,7 @@ differentiate all operations and function calls in `computeOutputPixel`. 
By wrap with `getInputElement` and by providing a custom backward propagation function of `getInputElement`, we are effectively telling the compiler what to do when a derivative propagates to an input tensor element. Inside the body of `getInputElement_bwd`, we define what to do then: atomically adds the derivative propagated to the input element -in the `inputGradToPropagateTo` tensor. Therefore after running `boxFilter_bwd`, the `inputGradToPropagateTo` tensor will contain all the +in the `inputGradToPropagateTo` tensor. Therefore, after running `boxFilter_bwd`, the `inputGradToPropagateTo` tensor will contain all the back propagated derivative values. Again, to understand all the details of the automatic differentiation system, please refer to the @@ -402,9 +402,9 @@ Again, to understand all the details of the automatic differentiation system, pl As shown in previous tutorial, Slang has defined the `TorchTensor<T>` and `TensorView<T>` type for interop with PyTorch tensors. The `TorchTensor<T>` represents the CPU view of a tensor and provides methods to allocate a new tensor object. -The `TensorView<T>` represents the GPU view of a tensor and provides accesors to read write tensor data. +The `TensorView<T>` represents the GPU view of a tensor and provides accessors to read write tensor data. -Following is a list of builtin methods and attributes for PyTorch interop. +Following is a list of built-in methods and attributes for PyTorch interop. ### `TorchTensor` methods @@ -501,7 +501,7 @@ Marks a function as a CUDA kernel (maps to a `__global__` function) Marks a function for export to Python. Functions marked with `[TorchEntryPoint]` will be accessible from a loaded module returned by `slangpy.loadModule`. #### `[CudaDeviceExport]` attribute -Marks a function as a cuda device function, and ensures the compiler to include it in the generated cuda source. 
+Marks a function as a CUDA device function, and ensures the compiler to include it in the generated CUDA source. ## Type Marshalling Between Slang and Python @@ -528,4 +528,4 @@ Calling `myFunc` from python will result in a python tuple in the form of [[tensor, tensor, tensor], float] ``` -The same transform rules applies to parameter types. +The same transform rules apply to parameter types. |
