From ae778e3424b39cbeb1f367339f654560de416d30 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Jan 2025 15:06:51 -0800 Subject: [Docs] Auto-diff documentation overhaul (#6202) * AD: Docs Update * More documentation * More documentation * More docs fixes * Cleanup documentation * More docs polish. Add docs for the [Differentiable] attributes * Fixup code sections * Fixup * Address review comments * regenerate documentation Table of Contents * Update docs with more playground links --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Yong He --- docs/user-guide/07-autodiff.md | 918 ++++++++++++++++++++++------------------- docs/user-guide/toc.html | 17 +- source/slang/core.meta.slang | 269 +++++++++++- source/slang/diff.meta.slang | 157 ++++++- 4 files changed, 912 insertions(+), 449 deletions(-) diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index 11c6677c4..53d26fa6e 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -5,569 +5,578 @@ permalink: /user-guide/autodiff # Automatic Differentiation -Neural networks and other machine learning techniques are becoming an increasingly popular way to solve many difficult problems in modern visual computing systems. However, to take advantage of these techniques, developers often need to reimplement many existing system components in a differentiable form to allow computing the derivatives of a function, or to propagate the derivative of a result backwards to each parameter. Slang provides built-in auto differentiation features to support developers adding differentiability to their existing code with as little effort as possible. In this chapter, we provide an overview of the auto differentiation features, followed by a detailed description on the new syntax and rules. 
+To support differentiable graphics systems such as Gaussian splatters, neural radiance fields, differentiable path tracers, and more, +Slang provides first-class support for differentiable programming. +An overview: +- Slang supports the `fwd_diff` and `bwd_diff` operators that can generate the forward and backward-mode derivative propagation functions for any valid Slang function annotated with the `[Differentiable]` attribute. +- The `DifferentialPair` built-in generic type is used to pass derivatives associated with each function input. +- The `IDifferentiable`, and the experimental `IDifferentiablePtrType`, interfaces denote differentiable value and pointer types respectively, and allow finer control over how types behave under differentiation. +- Further, Slang allows for user-defined derivative functions through the `[ForwardDerivative(custom_fn)]` and `[BackwardDerivative(custom_fn)]` attributes. +- All Slang features, such as control-flow, generics, interfaces, extensions, and more are compatible with automatic differentiation, though the bottom of this chapter documents some sharp edges & known issues. -## Using Automatic Differentiation in Slang +## Auto-diff operations `fwd_diff` and `bwd_diff` -In this section, we walk through the steps to compute forward-derivative from input, and backward propagate the derivative from output to input. +In Slang, `fwd_diff` and `bwd_diff` are higher-order functions used to transform Slang functions into their forward or backward derivative methods. To better understand what these methods do, here is a small refresher on differential calculus: +### Mathematical overview: Jacobian and its vector products +Forward and backward derivative methods are two different ways of computing a dot product with the Jacobian of a given function. +Parts of this overview are based on JAX's excellent auto-diff cookbook [here](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions).
The relevant [Wikipedia article](https://en.wikipedia.org/wiki/Automatic_differentiation) is also a great resource for understanding auto-diff. + +The [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) (also called the total derivative) of a function $\mathbf{f}(\mathbf{x})$ is represented by $D\mathbf{f}(\mathbf{x})$. -### Forward Differentiation +For a general function with multiple scalar inputs and multiple scalar outputs, the Jacobian is a _matrix_ where $D\mathbf{f}_{ij}$ represents the [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative) of the $i^{th}$ output element w.r.t the $j^{th}$ input element, $\frac{\partial f_i}{\partial x_j}$. -Suppose the user has already written a function that computes some mathematic term: +As an example, consider a polynomial function +$$ f(x, y) = x^3 + x^2 - y $$ +Here, $f$ has 1 output and 2 inputs. $Df$ is therefore the row matrix: +$$ Df(x, y) = [\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}] = [3x^2 + 2x, -1] $$ -```csharp -float myFunc(float a, float x) -{ - return a * x * x; -} -``` +Another, more complex example with a function that has multiple outputs (for clarity, denoted by $f_0$, $f_1$, etc.):
+$$ \mathbf{f}(x, y) = \begin{bmatrix} f_0(x, y) & f_1(x, y) & f_2(x, y) \end{bmatrix} = \begin{bmatrix} x^3 & y^2x & y^3 \end{bmatrix} $$ +Here, $Df$ is a 3x2 matrix with each element containing a partial derivative: +$$ D\mathbf{f}(x, y) = \begin{bmatrix} +\partial f_0 / \partial x & \partial f_0 / \partial y \\ +\partial f_1 / \partial x & \partial f_1 / \partial y \\ +\partial f_2 / \partial x & \partial f_2 / \partial y +\end{bmatrix} = +\begin{bmatrix} +3x^2 & 0 \\ +y^2 & 2yx \\ +0 & 3y^2 +\end{bmatrix} $$ -The user can make this function *forward-differentiable* by adding a `[ForwardDerivative]` attribute: -```csharp -[ForwardDifferentiable] -float myFunc(float a, float x) -{ - return a * x * x; -} -``` +Computing full Jacobians is often unnecessary and expensive. Instead, auto-diff offers ways to compute _products_ of the Jacobian with a vector, which is a much faster operation. +There are two basic ways to compute this product: + 1. the Jacobian-vector product $\langle D\mathbf{f}(\mathbf{x}), \mathbf{v} \rangle$, also called forward-mode autodiff, and can be computed using `fwd_diff` operator in Slang, and + 2. the vector-Jacobian product $\langle \mathbf{v}^T, D\mathbf{f}(\mathbf{x}) \rangle$, also called reverse-mode autodiff, and can be computed using `bwd_diff` operator in Slang. From a linear algebra perspective, this is the transpose of the forward-mode operator. -This allows the function to be used in the `fwd_diff` operator, which is a higher order operation that takes in a forward-differentiable function and returns the forward-derivative of the function. 
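+
+As a small worked illustration of the two products (a sketch for intuition only; the evaluation point $(x, y) = (1, 2)$ and both vectors $\mathbf{v}$ are arbitrary choices, not taken from the text above), evaluate the Jacobian of the three-output example:
+$$ D\mathbf{f}(1, 2) = \begin{bmatrix} 3 & 0 \\ 4 & 4 \\ 0 & 12 \end{bmatrix} $$
+The Jacobian-vector product with $\mathbf{v} = \begin{bmatrix} 1 & 0 \end{bmatrix}^T$ picks out the first column, i.e. the derivative of every output w.r.t $x$:
+$$ \langle D\mathbf{f}(1, 2), \mathbf{v} \rangle = \begin{bmatrix} 3 \\ 4 \\ 0 \end{bmatrix} $$
+The vector-Jacobian product with $\mathbf{v}^T = \begin{bmatrix} 0 & 1 & 0 \end{bmatrix}$ picks out the second row, i.e. the derivative of $f_1$ w.r.t every input:
+$$ \langle \mathbf{v}^T, D\mathbf{f}(1, 2) \rangle = \begin{bmatrix} 4 & 4 \end{bmatrix} $$
+In both cases a single product yields one column or one row (more generally, one linear combination of them) without materializing the full matrix.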
+#### Propagating derivatives with forward-mode auto-diff +The products described above allow the _propagation_ of derivatives forward and backward through the function $f$. -The expression `fwd_diff(myFunc)` will have the following signature: -```csharp -DifferentialPair myFunc_fwd_derivative(DifferentialPair a, DifferentialPair x); -``` +The forward-mode derivative (Jacobian-vector product) can convert a derivative of the inputs to a derivative of the outputs. +For example, let's say inputs $\mathbf{x}$ depend on some scalar $\theta$, and $\frac{\partial \mathbf{x}}{\partial \theta}$ is a vector of partial derivatives describing that dependency. -Where `DifferentialPair` is a built-in type that encodes both the primal(original) value and the derivative value of a term. -To use this function to compute the derivative of `myFunc` with regard to `x`, the user can call the forward-derivative function by supplying the derivative value of `x` with `1.0` and the derivative value of `a` with `0.0`, as in the following code: +Invoking forward-mode auto-diff with $\mathbf{v} = \frac{\partial \mathbf{x}}{\partial \theta}$ converts this into a derivative of the outputs w.r.t the same scalar $\theta$.
+This can be verified by expanding the Jacobian and applying the [chain rule](https://en.wikipedia.org/wiki/Chain_rule) of derivatives: +$$\langle D\mathbf{f}(\mathbf{x}), \frac{\partial \mathbf{x}}{\partial \theta} \rangle = \langle \begin{bmatrix} \frac{\partial f_0}{\partial x_0} & \frac{\partial f_0}{\partial x_1} & \cdots \\ \frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1} & \cdots \\ \cdots & \cdots & \cdots \end{bmatrix}, \begin{bmatrix} \frac{\partial x_0}{\partial \theta} \\ \frac{\partial x_1}{\partial \theta} \\ \cdots \end{bmatrix} \rangle = \begin{bmatrix} \frac{\partial f_0}{\partial \theta} \\ \frac{\partial f_1}{\partial \theta} \\ \cdots \end{bmatrix} = \frac{\partial \mathbf{f}}{\partial \theta}$$ -```csharp -float a = 2.0; -float x = 3.0; -// Compute derivative with regard to `x`: -let result = fwd_diff(myFunc)(diffPair(a, 0.0), diffPair(x, 1.0)); -// Print the derivative. -printf("%f", result.d); +#### Propagating derivatives with reverse-mode auto-diff +The reverse-mode derivative (vector-Jacobian product) can convert a derivative w.r.t outputs into a derivative w.r.t inputs. +For example, let's say we have some scalar $\mathcal{L}$ that depends on the outputs $\mathbf{f}$, and $\frac{\partial \mathcal{L}}{\partial \mathbf{f}}$ is a vector of partial derivatives describing that dependency. -// Output: 12.0 -``` +Invoking reverse-mode auto-diff with $\mathbf{v} = \frac{\partial \mathcal{L}}{\partial \mathbf{f}}$ converts this into a derivative of the same scalar $\mathcal{L}$ w.r.t the inputs $\mathbf{x}$.
+To provide more intuition for this, we can expand the Jacobian in the same way we did above: +$$\langle \frac{\partial \mathcal{L}}{\partial \mathbf{f}}^T, D\mathbf{f}(\mathbf{x}) \rangle = \langle \begin{bmatrix}\frac{\partial \mathcal{L}}{\partial f_0} & \frac{\partial \mathcal{L}}{\partial f_1} & \cdots \end{bmatrix}, \begin{bmatrix} \frac{\partial f_0}{\partial x_0} & \frac{\partial f_0}{\partial x_1} & \cdots \\ \frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1} & \cdots \\ \cdots & \cdots & \cdots \end{bmatrix} \rangle = \begin{bmatrix} \frac{\partial \mathcal{L}}{\partial x_0} & \frac{\partial \mathcal{L}}{\partial x_1} & \cdots \end{bmatrix} = \frac{\partial \mathcal{L}}{\partial \mathbf{x}}^T$$ -In the example code above, `diffPair()` is a built-in function to construct a value of `DifferentialPair` with a primal value and a derivative value. The primal value and derivative value stored in a `DifferentialPair` can be accessed with the `.p` and a `.d` property. +This mode is the most popular, since machine learning systems often construct their differentiable pipeline with multiple inputs (which can number in the millions or billions), and a single scalar output often referred to as the 'loss' denoted by $\mathcal{L}$. The desired derivative can be constructed with a single reverse-mode invocation. -### Backward Propagation +### Invoking auto-diff in Slang +With the mathematical foundations established, we can describe concretely how to compute derivatives using Slang. -The forward derivative function allows the user to compute the derivative of a function with regard to a specific combination of input parameters at a time. In many cases, we need to know how each parameter affects the output. Instead of calling the forward derivative function once for each parameter, it is more efficient to call the *backward propagation* function that propagate the derivative of outputs to each input parameter.
+In Slang, derivatives are computed using `fwd_diff` and `bwd_diff`, which correspond to Jacobian-vector and vector-Jacobian products respectively. +For forward-diff, to pass the vector $\mathbf{v}$ and receive the outputs, we use the `DifferentialPair` type. We use pairs of inputs because every input element $x_i$ has a corresponding element $v_i$ in the vector, and each original output element has a corresponding output element in the product. -To allow the compiler to generate the backward propagation function, we simply mark our function with the `[Differentiable]` or `[BackwardDifferentiable]` attribute: +Example of `fwd_diff`: ```csharp -[Differentiable] -float myFunc(float a, float x) -{ - return a * x * x; +[Differentiable] // Auto-diff requires that functions are marked differentiable +float2 foo(float a, float b) +{ + return float2(a * b * b, a * a); } -``` -> #### Note: -> When a function is marked as `[Differentiable]`, it is implied that the function is both `[ForwardDifferentiable]` and `[BackwardDifferentiable]` and can be used in the `fwd_diff` operator. - - -The `bwd_diff` operator applies to a backward differentiable function and returns the backward propagation function. In this case, `bwd_diff(myFunc)` will have the following signature: - -```csharp -void myFunc_backProp(inout DifferentialPair a, inout DifferentialPair x, float dResult); -``` +void main() +{ + DifferentialPair<float> dp_a = diffPair( + 1.0, // input 'a' + 1.0 // vector 'v' for Jacobian-vector product input (for 'a') + ); -Where `a` is an `inout DifferentialPair` where the initial value of `a` is passed into the function as primal value (in the `.p` property), and the propagated derivative of `a` is returned via the `.d` property of the `DifferentialPair`. The same rules apply to `x`. + DifferentialPair<float> dp_b = diffPair(2.4, 0.0); -The additional `dResult` parameter is the derivative of the return value to be propagated to the input parameters.
Note that in a backward propagation function, an input will become a `inout DifferentialPair` where the `.d` property of the pair is intended for receiving the propagation result, and the return value will become an input parameter that represents the source of backward propagation. + // fwd_diff to compute output and d_output w.r.t 'a'. + // Our output is also a differential pair. + // + DifferentialPair<float2> dp_output = fwd_diff(foo)(dp_a, dp_b); -The backward propagation function can be called as in the following code: -```csharp -var a = diffPair(2.0); // constructs DifferentialPair{2.0, 0.0} -var x = diffPair(3.0); // constructs DifferentialPair{3.0, 0.0} + // Extract output's primal part, which is just the standard output when foo is called normally. + // Can also use `.getPrimal()` + // + float2 output_p = dp_output.p; -bwd_diff(myFunc)(a, x, 1.0); + // Extract output's derivative part. Can also use `.getDifferential()` + float2 output_d = dp_output.d; -// a.d is now 9.0 -// x.d is now 12.0 + printf("foo(1.0, 2.4) = (%f %f)\n", output_p.x, output_p.y); + printf("d(foo)/d(a) at (1.0, 2.4) = (%f, %f)\n", output_d.x, output_d.y); +} ``` -This completes the walkthrough of automatic differentiation features. The following sections will cover each perspective of the auto differentiation feature in more detail. - -## Mathematic Concepts and Terminologies - -This section briefly reviews the mathematic theories behind differentiable programming with the intention to clarify the concepts and terminologies that will be used in the rest of this documentation. We assume the reader is already familiar with the basic theories behind neural network training, in particular the back-propagation algorithm.
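+
+The values the `fwd_diff` example above should print can be checked by hand (a sketch that assumes the generated derivative matches the analytic one). Since `foo(a, b)` computes $(ab^2, a^2)$:
+$$ \frac{\partial\, \mathrm{foo}}{\partial a} = \begin{bmatrix} b^2 \\ 2a \end{bmatrix}, \qquad \frac{\partial\, \mathrm{foo}}{\partial b} = \begin{bmatrix} 2ab \\ 0 \end{bmatrix} $$
+With the differential inputs $(1.0, 0.0)$ used in that example, the Jacobian-vector product is $1.0 \cdot \frac{\partial\, \mathrm{foo}}{\partial a} + 0.0 \cdot \frac{\partial\, \mathrm{foo}}{\partial b}$. At $(a, b) = (1.0, 2.4)$, `output_p` should therefore be $(5.76, 1.0)$ and `output_d` should be $(5.76, 2.0)$.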
- -A differentiable system can be represented a composition of differentiable functions (kernels) with learnable parameters, where each differentiable function has the form: - -$$\mathbf{w}_{i+1} = f_i(\mathbf{w}_i) $$ - -Where $$f_i$$ represents a differentiable function (kernel) in the system, $$\mathbf{w}$$ represents a collection of learnable parameters defined in function $$f_i$$, and $$\mathbf{w}_{i+1}$$ is the output of $$f_i$$. We will use $$\omega$$ to denote a specific parameter in $$\mathbf{w}$$. - -In a composed system, the value of $$\mathbf{w}$$ used to evaluate $$f_i$$ may come from an *upstream* function +Note that all the inputs and outputs to our function become 'paired'. This only applies to differentiable types, such as `float`, `float2`, etc. See the section on differentiable types for more info. -$$ \mathbf{w}_i = f_{i-1}(\mathbf{w}_{i-1}) $$ +`diffPair(primal_val, diff_val)` is a built-in utility function that constructs the pair from the primal and differential values. -Similarly, the value computed by $$f_i$$ may be used as argument to a *downstream* function +Additionally, invoking forward-mode also computes the regular (or 'primal') output value (can be obtained from `output.getPrimal()` or `output.p`). The same is _not_ true for reverse-mode. -$$ h = f_{i+1}(\mathbf{w}_{i+1}) = f_{i+1}(f_{i}(\mathbf{w}_{i}))$$ +For reverse-mode, the example proceeds in a similar way, and we still use `DifferentialPair` type. However, note that each input gets a corresponding _output_ and each output gets a corresponding _input_. Thus, all inputs become `inout` differential pairs, to allow the function to write into the derivative part (the primal part is still accepted as an input in the same pair data-structure). +The one extra rule is that the derivative corresponding to the return value of the function is accepted as the last argument (an extra input). This value does not need to be a pair. 
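+
+To make the reverse-mode rule concrete, here is a hand-derived sketch for the same `foo(a, b)` $= (ab^2, a^2)$ from the forward-mode example, assuming an output derivative of $\frac{\partial \mathcal{L}}{\partial \mathbf{f}} = (1.0, 0.0)$ (an illustrative choice) at $(a, b) = (1.0, 2.4)$:
+$$ \begin{bmatrix} \frac{\partial \mathcal{L}}{\partial a} & \frac{\partial \mathcal{L}}{\partial b} \end{bmatrix} = \begin{bmatrix} 1.0 & 0.0 \end{bmatrix} \begin{bmatrix} b^2 & 2ab \\ 2a & 0 \end{bmatrix} = \begin{bmatrix} b^2 & 2ab \end{bmatrix} = \begin{bmatrix} 5.76 & 4.8 \end{bmatrix} $$
+These are the values a `bwd_diff(foo)` call with that output derivative would be expected to write into the `.d` parts of the two input pairs.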
-The entire system composed from differentiable functions can be noted as - -$$Y = f_1 \circ f_2 \circ \cdots \circ f_n(\mathbf{w}_0)$$ - -Where $$\mathbf{w}_0$$ is the first layer of parameters. - -### Forward Propagation of Derivatives -When developing and training such a system, we often need to evaluate the partial derivative of a differentiable function with regard to some parameter $$\omega$$. The simplest way to obtain a partial derivative is to call a forward derivative propagation function, which is defined by: +Example: +```csharp +[Differentiable] // Auto-diff requires that functions are marked differentiable +float2 foo(float a, float b) +{ + return float2(a * b * b, a * a); +} -$$ \mathbb{F}[f_i] = f_i'(\mathbf{w}_i, \mathbf{w}_i') = \sum_{\omega_i\in\mathbf{w}_i} \frac{\partial f}{\partial \omega_i} \omega_i' $$ +void main() +{ + DifferentialPair<float> dp_a = diffPair( + 1.0 // input 'a' + ); // Calling diffPair without a derivative part initializes to 0. -Where $$\omega' \in \mathbf{w}'$$ represents the partial derivative of $$\omega_i$$ with regard to some upstream parameter $$\omega_{i-1}$$ that is used to compute $$\omega_i$$, i.e. $$\omega'=\frac{\partial \omega_{i}}{\partial \omega_{i-1}}$$. + DifferentialPair<float> dp_b = diffPair(2.4); -Given this definition, $$\mathbb{F}[f]$$ can be used as a forward propagation function that is able to compute $$\frac{\partial f_i}{\partial \omega_0}$$ from $$\frac{\partial \omega_{i-1}}{\partial \omega_0}$$. + // Derivatives of scalar L w.r.t output. + float2 dL_doutput = float2(1.0, 0.0); -### Backward Propagation of Derivatives -When using the backpropagation algorithm to train a neural network, we are more interested in figuring out the partial derivative of the final system output with regard to a parameter $$\omega_i$$ in $$f_i$$.
To do so, we generally utilize the backward derivative propagation function + // bwd_diff to compute dL_da and dL_db + // The derivative of the output is provided as an additional _input_ to the call + // Derivatives w.r.t inputs are written into dp_a.d and dp_b.d + // + bwd_diff(foo)(dp_a, dp_b, dL_doutput); -$$\mathbb{B}[f_i] = f_i^{-1}(\frac{\partial Y}{\partial f_i}) = \frac{\partial Y}{\partial \mathbf{w}_i}$$ + // Extract the derivatives of L w.r.t input + float dL_da = dp_a.d; + float dL_db = dp_b.d; -Where the backward propagation function $$\mathbb{B}[f_i]$$ takes as input the partial derivative of the final system output $$Y$$ with regard to the output of $$f_i$$ (i.e. $$\mathbf{w}_i$$), and computes the partial derivative of the final system output with regard to the input of $$f_i$$ (i.e. $$\mathbf{w}_{i-1}$$). + printf("If dL/dOutput = (1.0, 0.0), then (dL/da, dL/db) at (1.0, 2.4) = (%f, %f)", dL_da, dL_db); +} +``` -The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide built-in support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective. +## Differentiable Type System -## Differentiable Value Types Slang will only generate differentiation code for values that have a *differentiable* type. Differentiable types are defined through conformance to one of two built-in interfaces: 1. `IDifferentiable`: For value types (e.g. `float`, structs of value types, etc..) 2. `IDifferentiablePtrType`: For buffer, pointer & reference types that represent locations rather than values.
-The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios) -```csharp -interface IDifferentiable -{ - associatedtype Differential : IDifferentiable - where Differential.Differential == Differential; - - static Differential dzero(); - - static Differential dadd(Differential, Differential); -} -``` -As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of the second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`. +### Differentiable Value Types +All basic types (`float`, `int`, `double`, etc..) and all aggregate types (i.e. `struct`) that use any combination of these are considered value types in Slang. -In addition, a differentiable type must define the `zero` value of its derivative, and how to add two derivative values together. These function are used during reverse-mode auto-diff, to initialize and accumulate derivatives of the given type. - -By contrast, `IDifferentiablePtrType` only requires a `Differential` associated type which also conforms to `IDifferentiablePtrType`. -```csharp -interface IDifferentiablePtrType -{ - associatedtype Differential : IDifferentiablePtrType; - where Differential.Differential == Differential; -} -``` +Slang uses the `IDifferentiable` interface to define differentiable types. Basic types that describe a continuous value (`float`, `double` and `half`) and their vector/matrix versions (`float3`, `half2x2`, etc..) are defined as differentiable by the standard library. For all basic types, the type used for the differential (can be obtained with `T.Differential`) is the same as the primal. -> #### Note #### -> Support for `IDifferentiablePtrType` is still experimental. 
- -Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error. - - -### Builtin Differentiable Value Types +#### Builtin Differentiable Value Types The following built-in types are differentiable: - Scalars: `float`, `double` and `half`. - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. - Arrays: `T[n]` is differentiable if `T` is differentiable. - Tuples: `Tuple` is differentiable if `T` is differentiable. -### Builtin Differentiable Ptr Types -There are currently no built-in types that conform to `IDifferentiablePtrType` - -### User Defined Differentiable Types -The user can make any `struct` types differentiable by implementing either `IDifferentiable` & `IDifferentiablePtrType` interface on the type. -The requirements from `IDifferentiable` interface can be fulfilled automatically or manually, though `IDifferentiablePtrType` currently requires the user to provide the `Differential` type. +#### User-defined Differentiable Value Types -#### Automatic Fulfillment of `IDifferentiable` Requirements -Assume the user has defined the following type: +However, it is easy to define your own differentiable types. +Typically, all you need is to implement the `IDifferentiable` interface. ```csharp -struct MyRay +struct MyType : IDifferentiable { - float3 origin; - float3 dir; - int nonDifferentiablePayload; -} + float x; + float y; +}; ``` -The type can be made differentiable by adding `IDifferentiable` conformance: +The main requirement of a type implementing `IDifferentiable` is the `Differential` associated type that the compiler uses to carry the corresponding derivative. +In most cases the `Differential` of a type can be itself, though it can be different if necessary. 
+You can access the differential of any differentiable type through `Type.Differential` + +Example: ```csharp -struct MyRay : IDifferentiable +MyType obj; +obj.x = 1.f; + +MyType.Differential d_obj; +// Differentiable fields will have a corresponding field in the diff type +d_obj.x = 1.f; +``` + +Slang can automatically derive the `Differential` type in the majority of cases. +For instance, for `MyType`, Slang can infer the differential trivially: +```csharp +struct MyType : IDifferentiable { - float3 origin; - float3 dir; - int nonDifferentiablePayload; + // Automatically inserted by Slang from the fact that + // MyType has 2 floats which are both differentiable + // + typealias Differential = MyType; + // ... } ``` -Note that this code does not provide any explicit implementation of the `IDifferentiable` requirements. In this case the compiler will automatically synthesize all the requirements. This should provide the desired behavior most of the time. The procedure for synthesizing the interface implementation is as follows: -1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement. -2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type. -3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type. -4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`. -5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. 
This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements. - -#### Manual Fulfillment of `IDifferentiable` Requirements - -In rare cases where more control is desired, the user can manually provide the implementation. To do so, we will first define the `Differential` type for `MyRay`, and use it to fulfill the `Differential` requirement in `MyRay`: +For more complex types that aren't fully differentiable, a new type is synthesized automatically: ```csharp -struct MyRayDifferential +struct MyPartialDiffType : IDifferentiable { - float3 d_origin; - float3 d_dir; -} + // Automatically inserted by Slang based on which fields are differentiable. + typealias Differential = syn_MyPartialDiffType_Differential; + + float x; + uint y; +}; -struct MyRay : IDifferentiable +// Synthesized +struct syn_MyPartialDiffType_Differential { - // Specify that `MyRay.Differential` is `MyRayDifferential`. - typealias Differential = MyRayDifferential; - - // Specify that the derivative for `origin` will be stored in `MayRayDifferential.d_origin`. - [DerivativeMember(MayRayDifferential.d_origin)] - float3 origin; - - // Specify that the derivative for `dir` will be stored in `MayRayDifferential.d_dir`. - [DerivativeMember(MayRayDifferential.d_dir)] - float3 dir; + // Only one field since 'y' does not conform to IDifferentiable + float x; +}; +``` - // This is a non-differentiable field so we don't put any attributes on it. - int nonDifferentiablePayload; +You can make existing types differentiable through Slang's extension mechanism. +For instance, `extension MyType : IDifferentiable { }` will make `MyType` differentiable retroactively. - // Define zero derivative.
- static MyRayDifferential dzero() - { - return {float3(0.0), float3(0.0)}; - } +See the `IDifferentiable` [reference documentation](https://shader-slang.org/stdlib-reference/interfaces/idifferentiable-01/index) for more information on how to override the default behavior. - // Define the add operation of two derivatives. - static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) - { - MyRayDifferential result; - result.d_origin = v1.d_origin + v2.d_origin; - result.d_dir = v1.d_dir + v2.d_dir; - return result; - } -} -``` +#### DifferentialPair: Pairs of differentiable value types -Note that for each struct field that is differentiable, we need to use the `[DerivativeMember]` attribute to associate it with the corresponding field in the `Differential` type, so the compiler knows how to access the derivative for the field. +The `DifferentialPair` type is used to pass derivatives to a derivative call by representing a pair of values of type `T` and `T.Differential`. Note that `T` must conform to `IDifferentiable`. -However, there is still a missing piece in the above code: we also need to make `MyRayDifferential` conform to `IDifferentiable` because it is required that the `Differential` of a type must itself be `Differential`. Again we can use automatic fulfillment by simply adding `IDifferentiable` conformance to `MyRayDifferential`: -```csharp -struct MyRayDifferential : IDifferentiable -{ - float3 d_origin; - float3 d_dir; -} -``` -In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == float3` as defined in the core module), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`. +`DifferentialPair` can either be created via constructor calls or the `diffPair` utility method. 
-We can also choose to manually implement `IDifferentiable` interface for `MyRayDifferential` as in the following code: +Example: ```csharp -struct MyRayDifferential : IDifferentiable -{ - typealias Differential = MyRayDifferential; +MyType obj = {1.f, 2.f}; - [DerivativeMember(MyRayDifferential.d_origin)] - float3 d_origin; +MyType.Differential d_obj = {0.4f, 3.f}; - [DerivativeMember(MyRayDifferential.d_dir)] - float3 d_dir; +// The differential part of a differential pair is of the diff type. +DifferentialPair<MyType> dp_obj = diffPair(obj, d_obj); - static MyRayDifferential dzero() - { - return {float3(0.0), float3(0.0)}; - } +// Use .p to extract the primal part +MyType new_p_obj = dp_obj.p; - static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) - { - MyRayDifferential result; - result.d_origin = v1.d_origin + v2.d_origin; - result.d_dir = v1.d_dir + v2.d_dir; - return result; - } -} +// Use .d to extract the differential part +MyType.Differential new_d_obj = dp_obj.d; ``` -In this specific case, the automatically generated `IDifferentiable` implementation will be exactly the same as the manually written code listed above. +### Differentiable Ptr types +Pointer types are any type that represents a location or reference to a value rather than the value itself. +Examples include resource types (`RWStructuredBuffer`, `Texture2D`), pointer types (`Ptr`) and references. -## Forward Derivative Propagation Function +The `IDifferentiablePtrType` interface can be used to denote types that need to transform into pairs during auto-diff. However, unlike +an `IDifferentiable` type whose derivative portion is an _output_ under `bwd_diff`, the derivative part of `IDifferentiablePtrType` remains an input. This is because only the value is returned as an output, while the location where it needs to be written to, is still effectively an input to the derivative methods. -Functions in Slang can be marked as forward-differentiable or backward-differentiable.
The `fwd_diff` operator can be used on a forward-differentiable function to obtain the forward derivative propagation function. Likewise, the `bwd_diff` operator can be used on a backward-differentiable function to obtain the backward derivative propagation function. This and the next sections cover the semantics of forward and backward propagation functions, and different ways to make a function forward and backward differentiable. +> #### Note #### +> Support for `IDifferentiablePtrType` is still experimental. There are no built-in types conforming to this interface, though we plan to add stdlib support in the near future. -A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. -Given an original function, the signature of its forward propagation function is determined using the following rules: -- If the return type `R` implements `IDifferentiable` the forward propagation function will return a corresponding `DifferentialPair` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. -- If a parameter has type `T` that implements `IDifferentiable`, it will be translated into a `DifferentialPair` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters. -- If a parameter has type `T` that implements `IDifferentiablePtrType`, it will be translated into a `DifferentialPtrPair` parameter where the differential component references the differential location or buffer. -- All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. -- Differentiable methods cannot have a type implementing `IDifferentiablePtrType` as an `out` or `inout` parameter, or a return type. 
Types implementing `IDifferentiablePtrType` can only be used for input parameters to a differentiable method. Marking such a method as `[Differentiable]` will result in a compile-time diagnostic error. +`IDifferentiablePtrType` only requires a `Differential` associated type to be specified. -For example, given original function: -```csharp -R original(T0 p0, inout T1 p1, T2 p2, T3 p3); -``` -Where `R`, `T0`, `T1 : IDifferentiable`, `T2` is non-differentiable, and `T3 : IDifferentiablePtrType`, the forward derivative function will have the following signature: +#### DifferentialPtrPair: Pairs of differentiable ptr types +For types conforming to `IDifferentiablePtrType`, the corresponding pair to use for passing the derivative counterpart is `DifferentialPtrPair`, which represents a pair of `T` and `T.Differential`. Objects of this type can be created using a constructor. + +#### Example of defining and using an `IDifferentiablePtrType` object. +Here is an example of creating a differentiable buffer pointer type, and using it within a differentiable function. 
+You can find an interactive sample on the Slang playground [here](https://shader-slang.org/slang-playground/?target=WGSL&code=eJy1VF1v2kAQfPevWEWKYhfkmFdMkBrRSpHKhyBSpdIIHfgcTjFn9z4gEeK_d-_ONsZp1L6UF8MxOzszuz62K3KhoMjI27PINU9iTyqhNwrGb_c6TamY5YwrKqAPDyNmDihXjKwzOlPi8a2g3tED_NzewuOWQnKGZFAQpGYSCM_VFikYl4rwDYU8bdOHlkQhH8kYkTBq8ty10bFn4fPvC6tVC5q4_wdplhM1hLVOYwvRiMd2qaQq9k5Yhzq_Mf4CBDZaqnwHCRVsTxTbU295TzYvByKSUX3mI1-yWh-S4Mmz3GAO_HY4Rdd1Yjyhr0EZiaCojEMRopplEToV0HGgJ5Rj1UxyRUFtkRkzgnWpoCELJHvmxJg0WUrFsgwThRvGb9pxM8xxn7MEKtV-M0cc2Awhg5b44aX6LjifyVSryknbrmm7KpTAyRRh4pKuzqzb-kfLNHTuLLE1v7zcpypgqXfTdPFLE0HlIMPiCe4e9h2-T73SVxbmEgVFYVqux3JMXh8QJ_0JTs_icgG-s2qQMT4GMMFHpxNYgKM7U-5JtjJQO3SMiQVxjTDt0I6DfHJP9-_Ja84fcc7u-SXr9-efJyO_F6Guj5eY8UIrrL2s_PFlPl38rdRujym121AItDwmjPsfDdTN8ug6diE6OSO20L_6yoRU0IuMR01lH37yqzKIPybaixqRNniukz5cp1iMQXbLTJUwqQblyErgQu_MHSHdEpivaVtCyXOxLL1oaAhrtndra1Ip9_boInJeqxtsRiTeVvZFMk0LV4dH0g3D3VL4Xq3Mgvvt5oFfO_6X986Zr0UF3bq6F0Zlvs1UzreSEYc9dabgEIrQXR3_ZT61unJKp98JDfhi). ```csharp -DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2, DifferentialPtrPair p3); -``` +struct MyBufferPointer : IDifferentiablePtrType +{ + // The differential part is another instance of MyBufferPointer. + typealias Differential = MyBufferPointer; -This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged. + RWStructuredBuffer buf; + uint offset; +}; -`DifferentialPair` is a built-in type that carries both the original and derivative value of a term. 
It is defined as follows: -```csharp -struct DifferentialPair : IDifferentiable +// Link a custom derivative +[BackwardDerivative(load_bwd)] +float load(MyBufferPointer p, uint index) { - typealias Differential = DifferentialPair; - property T p {get;} - property T.Differential d {get;} - static Differential dzero(); - static Differential dadd(Differential a, Differential b); + return p.buf[p.offset + index]; } -``` -For ptr-types, there is a corresponding built-in `DifferentialPtrPair` that does not have the `dzero` or `dadd` methods. +// Note that the backward derivative signature is still an 'in' differential pair. +void load_bwd(DifferentialPtrPair p, uint index, float dOut) +{ + MyBufferPointer diff_ptr = p.d; + diff_ptr.buf[diff_ptr.offset + index] += dOut; +} -### Automatic Implementation of Forward Derivative Functions +[Differentiable] +float sumOfSquares(MyBufferPointer p) +{ + float sos = 0.f; -A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward propagation function. The syntax for using `[ForwardDifferentiable]` is: + [MaxIters(N)] + for (uint i = 0; i < N; i++) + { + float val_i = load(p, i); + sos += val_i * val_i; + } -```csharp -[ForwardDifferentiable] -R original(T0 p0, inout T1, p1, T2 p2); -``` + return sos; +} -Once the function is made forward-differentiable, the forward propagation function can then be called with the `fwd_diff` operator: -```csharp -DifferentialPair result = fwd_diff(original)(...); +RWStructuredBuffer inputs; +RWStructuredBuffer derivs; + +void main() +{ + MyBufferPointer ptr = {inputs, 0}; + print("Sum of squares of first 10 values: ", sumOfSquares<10>(ptr)); + + MyBufferPointer deriv_ptr = {derivs, 0}; + + // Pass a pair of pointers as input. 
+ bwd_diff(sumOfSquares<10>)( + DifferentialPtrPair(ptr, deriv_ptr), + 1.0); + + print("Derivative of result w.r.t the 10 values: \n"); + for (uint i = 0; i < 10; i++) + print("%d: %f\n", i, load(deriv_ptr, i)); +} ``` -### User Defined Forward Derivative Functions -As an alternative to compiler-implemented forward derivatives, the user can choose to manually provide a derivative implementation to make an existing function forward-differentiable. The `[ForwardDerivative(derivative_func)]` attribute is used to associate a function with its forward derivative propagation implementation. The syntax for using `[ForwardDerivative]` attribute is: +## User-Defined Derivative Functions + +As an alternative to compiler-generated derivatives, you can choose to provide an implementation for the derivative, which the compiler will use instead of attempting to generate one. + +This can be performed on a per-function basis by using the decorators `[ForwardDerivative(fwd_deriv_func)]` and `[BackwardDerivative(bwd_deriv_func)]` to reference the derivative from the primal function. + +For instance, it often makes little sense to differentiate the body of a `sin(x)` implementation, when we know that the derivative is `cos(x) * dx`. In Slang, this can be represented in the following way: ```csharp -DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2) +DifferentialPair sin_fwd(DifferentialPair dpx) { - .... + float x = dpx.p; + float dx = dpx.d; + return DifferentialPair(dpx.p, cos(x) * dx); } -[ForwardDerivative(derivative)] -R original(T0 p0, inout T1, p1, T2 p2); -``` -If `derivative` is defined in a different scope from `original`, such as in a different namespace or `struct` type, a fully qualified name is required. For example: -```csharp -struct MyType +// sin() is now considered differentiable (at least for forward-mode) since it provides +// a derivative implementation. 
+// +[ForwardDerivative(sin_fwd)] +float sin(float x) { - // Implementing derivative function in a different name scope. - static DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2) - { - .... - } + // Calc sin(X) using Taylor series.. } -// Use fully qualified name in the attribute. -[ForwardDerivative(MyType.derivative)] -R original(T0 p0, inout T1, p1, T2 p2); +// Any uses of sin() in a `[Differentiable]` function will automatically use the sin_fwd implementation when differentiated. ``` -Sometimes the derivative function needs to be defined in a different module from the original function, or the derivative function cannot be made visible from the original function. In this case, we can use the `[ForwardDerivativeOf(originalFunc)]` attribute to inform the compiler that `originalFunc` should be treated as a forward-differentiable function, and the current function is the derivative implementation of `originalFunc`. The following code will have the same effect to associate `derivative` and the forward-derivative implementation of `original`: - +A similar example for a backward derivative. ```csharp -R original(T0 p0, inout T1, p1, T2 p2); +void sin_bwd(inout DifferentialPair dpx, float dresult) +{ + float x = dpx.p; + + // Write-back the derivative to each input (the primal part must be copied over as-is) + dpx = DifferentialPair(x, cos(x) * dresult); +} -[ForwardDerivativeOf(original)] -DifferentialPair derivative(DifferentialPair p0, inout DifferentialPair p1, T2 p2) +[BackwardDerivative(sin_bwd)] +float sin(float x) { - .... + // Calc sin(X) using Taylor series.. } ``` -## Backward Derivative Propagation Function +> Note that the signature of the provided forward or backward derivative function must match the expected signature from invoking `fwd_diff(fn)`/`bwd_diff(fn)`. +> For a full list of signature rules, see the reference section for the [auto-diff operators](#fwd_difff--slang_function---slang_function). 
-A backward derivative propagation function propagates the derivative of the function output to all the input parameters simultaneously. +### Back-referencing User Derivative Attributes. +Sometimes, the original function's definition might be inaccessible, so it can be tricky to add an attribute to create the association. -Given an original function `f`, the general rule for determining the signature of its backward propagation function is that a differentiable output `o` becomes an input parameter holding the partial derivative of a downstream output with regard to the differentiable output, i.e. $$\partial y/\partial o$$); an input differentiable parameter `i` in the original function will become an output in the backward propagation function, holding the propagated partial derivative $$\partial y/\partial i$$; and any non-differentiable outputs are dropped from the backward propagation function. This means that the backward propagation function never returns any values computed in the original function. +For such cases, Slang provides the `[ForwardDerivativeOf(primal_fn)]` and `[BackwardDerivativeOf(primal_fn)]` attributes that can be used +on the derivative function and contain a reference to the function for which they are providing a derivative implementation. +As long as the derivative function is in scope, the primal function will be considered differentiable. -More specifically, the signature of its backward propagation function is determined using the following rules: -- A backward propagation function always returns `void`. -- A differentiable `in` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. 
The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input. -- A differentiable `out` parameter of type `T : IDifferentiable` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value. -- A differentiable `inout` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. -- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function. -- A non-differentiable return value of type `NDR` will be dropped. -- A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function. -- A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function. -- A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter. -- Types implemented `IDifferentiablePtrType` work the same was as the forward-mode case. They can only be used with `in` parameters, and are converted into `DifferentialPtrPair` types. Their directions are not affected. 
- -For example consider the following original function: +Example: ```csharp -struct T : IDifferentiable {...} -struct R : IDifferentiable {...} -struct P : IDifferentiablePtrType {...} -struct ND {} // Non differentiable +// Module A +float sin(float x) { /* ... */ } -[Differentiable] -R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6); +// Module B +import A; +[BackwardDerivativeOf(sin)] // Add a derivative implementation for sin() in module A. +void sin_bwd(inout DifferentialPair dpx, float dresult) { /* ... */ } ``` -The signature of its backward propagation function is: -```csharp -void back_prop( - inout DifferentialPair p0, - T.Differential p1, - inout DifferentialPair p2, - ND p3, - ND p5, - DifferentialPtrPair

p6, - R.Differential dResult); -``` -Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`. -### Automatically Implemented Backward Propagation Functions +User-defined derivatives also work for generic functions, member functions, accessors, and more. +See the reference section for the [`[ForwardDerivative(fn)]`](https://shader-slang.org/stdlib-reference/attributes/forwardderivative-07.html) and [`[BackwardDerivative(fn)]`](https://shader-slang.org/stdlib-reference/attributes/backwardderivative-08) attributes for more. -A function can be made backward-differentiable with a `[Differentiable]` or `[BackwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the backward propagation function. The syntax for using `[Differentiable]` is: +## Using Auto-diff with Generics +Automatic differentiation works seamlessly with generically-defined types and methods. +For generic methods, differentiability of a type is defined either through an explicit `IDifferentiable` constraint or any other +interface that extends `IDifferentiable`. +Example for generic methods: ```csharp [Differentiable] -R original(T0 p0, inout T1, p1, T2 p2); +T calcFoo(T x) { /* ... */ } + +[Differentiable] +T calcBar(T x) { /* ... */ } + +[Differentiable] +void main() +{ + DifferentialPair dpa = /* ... */; + + // Can call with any type that is IDifferentiable. Generic parameters + // are inferred like any other call. + // + bwd_diff(calcFoo)(dpa, float4(1.f)); + + // But you can also be explicit with < > + bwd_diff(calcFoo)(dpa, float4(1.f)); + + // x is differentiable for calcBar because + // __BuiltinFloatingPointType : IDifferentiable + // + DifferentialPair dpb = /* .. 
 */; + bwd_diff(calcBar)(dpb, 1.0); +} ``` -Once the function is made backward-differentiable, the backward propagation function can then be called with the `bwd_diff` operator: +You can implement `IDifferentiable` on a generic type. Automatic synthesis still applies and will use +generic constraints to resolve whether a field is differentiable or not. ```csharp -bwd_diff(original)(...); +struct Foo : IDifferentiable +{ + T t; + U u; +}; + +// The synthesized Foo.Differential will contain a field for +// 't' but not 'u' +// ``` -### User Defined Backward Propagation Functions -Similar to user-defined forward derivative functions, the `[BackwardDerivative]` and `[BackwardDerivativeOf]` attributes can be used to supply a function with user defined backward propagation function. +## Using Auto-diff with Interface Requirements and Interface Types +For interface requirements, using the `[Differentiable]` attribute enforces that any implementation of that method must also be +differentiable. You can, of course, provide a manual derivative implementation to satisfy the requirement. -The syntax for using `[BackwardDerivative]` attribute is: +The following is a sample snippet. 
You can run the full sample on the playground [here](https://shader-slang.org/slang-playground/?target=HLSL&code=eJyVVMtu2zAQvOsrFgEKy4Wq1C7QQ1330AYBcujjnhbBWiRjphQpUJQjI8i_d0mRquLYASLYtLwc7sySs5R1Y6yDRuH-1ppOs1UmteNWYMXh6tKY7CEDeq4vpBDccu0kbhT_E4JCGXRQoary4bWfr7LHLGud7SoHtPqqbhR8miYagJTeGbvKQuj8HDyO15QdnTQadhIBO2dq-lsB-0_tZ8tXCQrxgdo_lrvOaujhLXwoBY1RCQTE43P1y5fk-8gLJdSoO1RqD401O8mkvgXGrdwRYseh5m5rWBvLuTT2Hi27GOdzX8aNuGfzobbrr1j9PQbZjJBXlb-clh-rDz-TjVW_UNrPIdcXSHryU4BTbP78PC7vy2YgLiJ7X7JRw_yJiJ2RDFJ5udSmcyeF9UWsnFnedsodyuhhfWqtl1SkdQebMgoiSd4BcMvdz81d3lGDgNnc3Ug2j66QAvIhAus1LA4FpEaQfljDw6L8KB5Xh9vkZxOlH7lq-fFEyzET6X05EWl_1inRJqZuOseTUwo4UtfxZv1mOToOSEy6dajppjACuHRbbpPEBZjxfRkWhi0U9F2njYxcsfdS9ou9xjp0fdugq7bgDGBDHdRY6Wln3hWzMsLTqh-GptzWqyUK6VquBMgWtNHv2JP6CxLORrYtesz0ilHA0GEBG3LcrJ8F9GwwyCytupdKkTut3U8aui3bqagdWoi-WntRZehLf0Nmk7MaEOHWDJanIrX7jlLn6QxOuZ41SInH3lqW76NhqWFufDiPJzzPCVrAgj6lSPSBJz97I37rs8LnKpm_u_8BU5nW2Q). ```csharp -void back_prop( - inout DifferentialPair p0, - T1.Differential p1, - inout DifferentialPair p2, - ND p3, - ND p5, - DifferentialPtrPair

p6, - R.Differential dResult) +interface IFoo { - ... + [Differentiable] + float calc(float x); } -[BackwardDerivative(back_prop)] -R original(T0 p0, inout T1, p1, T2 p2); -``` +struct FooImpl : IFoo +{ + // Implementation via automatic differentiation. + [Differentiable] + float calc(float x) + { /* ... */ } +} -Similarly, the `[BackwardDerivativeOf]` attribute can be used on the back-prop function in case it is not convenient to modify the definition of the original function, or the back-prop function can't be made visible from the original function: +struct FooImpl2 : IFoo +{ + // Implementation via manually providing derivative methods. + [ForwardDerivative(calc_fwd)] + [BackwardDerivative(calc_bwd)] + float calc(float x) + { /* ... */ } -```csharp -R original(T0 p0, inout T1, p1, T2 p2); + DifferentialPair calc_fwd(DifferentialPair x) + { /* ... */ } -[BackwardDerivativeOf(original)] -void back_prop( - inout DifferentialPair p0, - T1.Differential p1, - inout DifferentialPair p2, - ND p3, - ND p5, - DifferentialPtrPair

 p6, - R.Differential dResult) + void calc_bwd(inout DifferentialPair x, float dresult) + { /* ... */ } +} + +[Differentiable] +float compute(float x, uint obj_id) { - ... + // Create an instance of either FooImpl or FooImpl2 + IFoo foo = createDynamicObject(obj_id); + + // Dynamic dispatch to appropriate 'calc'. + // + // Note that foo itself is non-differentiable, and + // has no differential data, but 'x' and 'result' + // will carry derivatives. + // + var result = foo.calc(x); + return result; } ``` -## Builtin Differentiable Functions +### Differentiable Interface (and Associated) Types +> Note: This is an advanced use-case and support is currently experimental. -The following built-in functions are backward differentiable and both their forward-derivative and backward-propagation functions are already defined in the core module: +You can have an interface or an interface associated type extend `IDifferentiable` and use that in differentiable interface requirement functions. This is often important in large code-bases with modular components that are all differentiable (one example is the material system in large production renderers). -- Arithmetic functions: `abs`, `max`, `min`, `sqrt`, `rcp`, `rsqrt`, `fma`, `mad`, `fmod`, `frac`, `radians`, `degrees` -- Interpolation and clamping functions: `lerp`, `smoothstep`, `clamp`, `saturate` -- Trigonometric functions: `sin`, `cos`, `sincos`, `tan`, `asin`, `acos`, `atan`, `atan2` -- Hyperbolic functions: `sinh`, `cosh`, `tanh` -- Exponential and logarithmic functions: `exp`, `exp2`, `pow`, `log`, `log2`, `log10` -- Vector functions: `dot`, `cross`, `length`, `distance`, `normalize`, `reflect`, `refract` -- Matrix transforms: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)` -- Matrix operations: `transpose`, `determinant` -- Legacy blending and lighting intrinsics: `dst`, `lit` - -## Primal Substitute Functions +Here is a snippet of how to make an interface and associated type (and by consequence 
all its implementations) differentiable. +For a full working sample, check out the Slang playground [here](https://shader-slang.org/slang-playground/?target=WGSL&code=eJylVVFvmzAQfudXnCpVgoXRhK4vpdnLuodIq9pqe9umygGzuXPANaYjqvLfd8bgmIR0a-eHcHfcnT9_38WwlSilAsHJ-ocs6yJLPI8VisqcpBQWHwhPa04UKws4h8Uly3MqaaEYWXLqPXmA6-sw-r0N5rwkCogQfO0buw4SbzPs_kWSospLuTrYm1RVmTKiaKbWgsIOHMfFxgexuFUr8ov6HZJKyTpV8IkVlMibkq-LcsUI3-ncIekOlDjO0jgvIaEJ4AkkVbUsgMAbaGCCbWDjbSyc25pkEndOX4_IOOlznDwPzbfYgs5IhyANZwstpSi3glhBO4haNMIZqQYazPcod2ELNRu68cu0bcNme7321CUpAnjy1U9WRdgc3kJnzoLQmpvENujVSk1oVKpXEzEitjNUhwnZuiuW3Qn1XxSNTdxDy5JN0SvGUdCETeBda82QOm0ZBOEgdxvHpDObjuXDPAxbf5_zB5fzvbM514c-1vXy3q_xcgGWhR03j4TPHDsOOjVYDj7LYD6H6fi4DPXkTHNhmuk2-0A564HqX8oreojiYeeHnc4h-JplHUCaW8hwAqdRPsINe46b7gZA5Q0nev56JoTlRMS91fTUOKSWy3tE11NrOuhaEQfduA0-D3pgsCTqb1gHaxqZi6bRF6_nPZZIvpCI64qwwu-3boFeXV9-_Iyd-hkvJXSqYnCa4OPC5KA5meyqZw7TfmHEDU4clkRxcuB1jK9P3dctJP9oKNFVmVE4zsBv5tPoLDiH4_xbcRQCCw29-LT7bU0kVmf3ROnlSMRvCJMXLZr3kIkGgWT4Vkd9XZb8Q9HCOaQttkhe1CIeaxE7LZa_szud4OsTB_rIzv6uE2unCWEWTd2jd8RmvqRVzVVwkuEoWCaxIsqCPRncbH05O_l277_XxWN1sa3Df88fIn-viQ) -Sometimes it is desirable to replace a function with another when generating forward or backward derivative propagation code. For example, the following code shows a function that computes the integral of some term by sampling and we want to use a different sampling strategy when computing the derivatives. ```csharp -float myTerm(float x) +interface IFoo : IDifferentiable { - return someComplexComputation(x); -} + associatedtype BaseType : IDifferentiable; -float getSample(float a, float b) { ... 
} + [Differentiable] + BaseType foo(BaseType x); +}; [Differentiable] -float computeIntegralOverMyTerm(float x, float a, float b) +float calc(float x) { - float sum = 0.0; - for (int i = 0; i < SAMPLE_COUNT; i++) - { - let s = no_diff getSample(a, b); - let y = myTerm(s); - sum += y * ((b-a)/SAMPLE_COUNT); - } - return sum; + // Note that since IFoo is differentiable, + // any data in the IFoo implementation is differentiable + // and will carry derivatives. + // + IFoo obj = makeObj(/* ... */); + return obj.foo(x); } ``` -In this code, the `getSample` function returns a random sample in the range of `[a,b]`. Assume we have another sampling function `getSampleForDerivativeComputation(a,b)` that we wish to use instead in derivative computation, we can do so by marking it as a primal-substitute of `getSample`, as in the following code: -```csharp -[PrimalSubstituteOf(getSample)] -float getSampleForDerivativeComputation(float a, float b) -{ - ... -} -``` - -Here, the `[PrimalSubstituteOf(getSample)]` attributes marks the `getSampleForDerivativeComputation` function as the substitute for `getSample` in derivative propagation functions. When a function has a primal substitute, the compiler will treat all calls to that function as if it is a call to the substitute function when generating derivative code. Note that this only applies to compiler generated derivative function and does not affect user provided derivative functions. If a user provided derivative function calls `getSample`, it will not be replaced by `getSampleForDerivativeComputation` by the compiler. +Under the hood, Slang will automatically construct an anonymous abstract type to represent the differentials. +However, on targets that don't support true dynamic dispatch, these are lowered into tagged unions. +While we are working to improve the implementation, this union can currently include all active differential +types, rather than just the relevant ones. This can lead to increased memory use. 
-Similar to `[ForwardDerivative]` and `[ForwardDerivativeOf]` attributes, The `[PrimalSubstitute(substFunc)]` attribute works the other way around: it specifies the primal substitute function of the function being marked. - -Primal substitute can be used as another way to make a function differentiable. A function is considered differentiable if it has a primal substitute that is differentiable. The following code illustrates this mechanism. -```csharp -float myFunc(float x) {...} +## Primal Substitute Functions -[PrimalSubstituteOf(myFunc)] -[Differentiable] -float myFuncSubst(float x) {...} +Sometimes it is desirable to replace a function with another when generating derivative code. +Most often, this is because a lot of shader operations may just not have a function body, such as hardware intrinsics for +texture sampling. In such cases, Slang provides a `[PrimalSubstitute(fn)]` attribute that can be used to provide +a reference implementation that Slang can differentiate to generate the derivative function. -// myFunc is now considered backward differentiable. -``` +The following is a small snippet with bilinear texture sampling. For a full example application that uses this concept, see the [texture differentiation sample](https://github.com/shader-slang/slang/tree/master/examples/autodiff-texture) in the Slang repository. -The following example shows in more detail on how primal substitute affects derivative computation. ```csharp -float myFunc(float x) { return x*x; } - -[PrimalSubstituteOf(myFunc)] -[ForwardDifferentiable] -float myFuncSubst(float x) { return x*x*x; } +[PrimalSubstitute(sampleTextureBilinear_reference)] +float4 sampleTextureBilinear(Texture2D x, float2 loc) +{ + // HW-accelerated sampling intrinsics. + // Slang does not have access to body, so cannot differentiate. + // + x.Sample(/*...*/) +} -[ForwardDifferentiable] -float caller(float x) { return myFunc(x); } +// Since the substitute is differentiable, so is `sampleTextureBilinear`. 
+[Differentiable] +float4 sampleTextureBilinear_reference(Texture2D x, float2 loc) +{ + // Reference SW interpolation, that is differentiable. +} -let a = caller(4.0); // a == 16.0 (calling myFunc) -let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; // b == 64.0 (calling myFuncSubst) -let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; // c == 48.0 (calling derivative of myFuncSubst) +[Differentiable] +float computePixel(Texture2D x, float a, float b) +{ + // Slang will use HW-accelerated sampleTextureBilinear for standard function + // call, but differentiate the SW reference interpolation during backprop. + // + float4 sample1 = sampleTextureBilinear(x, float2(a, 1)); +} ``` -In case that a function has both custom defined derivatives and a differentiable primal substitute, the primal substitute overrides the custom defined derivative on the original function. All calls to the original function will be translated into calls to the primal substitute first, and differentiation step follows after. This means that the derivatives of the primal substitute function will be used instead of the derivatives defined on the original function. +Similar to `[ForwardDerivativeOf(fn)]` and `[BackwardDerivativeOf(fn)]` attributes, Slang provides a `[PrimalSubstituteOf(fn)]` attribute that can be used on the substitute function to reference the primal one. ## Working with Mixed Differentiable and Non-Differentiable Code @@ -655,7 +664,7 @@ struct MyType : IDifferentiable no_diff float member; float someOtherMember; } -[ForwardDifferentiable] +[Differentiable] float f(float x) { MyType t; @@ -668,7 +677,7 @@ let result = fwd_diff(f)(diffPair(3.0, 1.0)).d; // result == 0.0 In this case, we are assigning the value `x*x`, which carries a derivative, into a non-differentiable location `MyType.member`, thus throwing away any derivative info. When `f` returns `t.member`, there will be no derivative associated with it, so the function will not propagate the derivative through. 
This code is most likely not intending to discard the derivative through the assignment. To help avoid this kind of unintentional behavior, Slang will treat any assignments of a value with derivative info into a non-differentiable location as a compile-time error. To eliminate this error, the user should either make `t.member` differentiable, or to force the assignment by clarifying the intention to discard any derivatives using the built-in `detach` method. The following code will compile, and the derivatives will be discarded: ```csharp -[ForwardDifferentiable] +[Differentiable] float f(float x) { MyType t; @@ -688,7 +697,7 @@ float g(float x) return 2*x; } -[ForwardDifferentiable] +[Differentiable] float f(float x) { // Error: implicit call to non-differentiable function g. @@ -698,7 +707,7 @@ float f(float x) The derivative will not propagate through the call to `g` in `f`. As a result, `fwd_diff(f)(diffPair(1.0, 1.0))` will return `{3.0, 2.0}` instead of `{3.0, 4.0}` as the derivative from `2*x` is lost through the non-differentiable call. To prevent unintended error, it is treated as a compile-time error to call `g` from `f`. If such a non-differentiable call is intended, a `no_diff` prefix is required in the call: ```csharp -[ForwardDifferentiable] +[Differentiable] float f(float x) { // OK. The intention to call a non-differentiable function is clarified. @@ -743,7 +752,7 @@ float result = fwd_diff(use)(obj, diffPair(2.0, 1.0)).d; // being generated regardless of the original code. ``` -## Higher Order Differentiation +## Higher-Order Differentiation Slang supports generating higher order forward and backward derivative propagation functions. It is allowed to use `fwd_diff` and `bwd_diff` operators inside a forward or backward differentiable function, or to nest `fwd_diff` and `bwd_diff` operators. 
For example, `fwd_diff(fwd_diff(sin))` will have the following signature: @@ -755,18 +764,95 @@ The input parameter `x` contains four fields: `x.p.p`, `x.p.d,`, `x.d.p`, `x.d.d User defined higher-order derivative functions can be specified by using `[ForwardDerivative]` or `[BackwardDerivative]` attribute on the derivative function, or by using `[ForwardDerivativeOf]` or `[BackwardDerivativeOf]` attribute on the higher-order derivative function. -## Interactions with Generics and Interfaces +## Restrictions and Known Issues -Automatic differentiation for generic functions is supported. The forward-derivative and backward propagation functions of a generic function is also a generic function with the same set of generic parameters and constraints. Using `[ForwardDerivative]`, `[ForwardDerivativeOf]`, `[BackwardDerivative]` or `[BackwardDerivativeOf]` attributes to associate a derivative function with different set of generic parameters or constraints is a compile-time error. +The compiler can generate forward derivative and backward propagation implementations for most uses of array and struct types, including arbitrary read and write access at dynamic array indices, and supports uses of all types of control flows, mutable parameters, generics and interfaces. This covers the set of operations that is sufficient for a lot of functions. However, the user needs to be aware of the following restrictions when using automatic differentiation: -An interface method requirement can be marked as `[ForwardDifferentiable]` or `[Differentiable]`, so they may be called in a forward or backward differentiable function and have the derivatives propagate through the call. This works regardless of whether the call can be specialized or has to go through dynamic dispatch. However, calls to interface methods are only differentiable once. Higher order differentiation through interface method calls are not supported. 
+- All operations to global resources, global variables and shader parameters, including texture reads or atomic writes, are treated as non-differentiable operations. Slang provides support for special data-structures (such as `Tensor`) through libraries such as `SlangPy`, which come with custom derivative implementations. +- If a differentiable function contains calls that cause side-effects such as updates to global memory, there is currently no guarantee on how many times side-effects will occur during the resulting derivative function or back-propagation function. +- Loops: Loops must have a bounded number of iterations. If this cannot be inferred statically from the loop structure, the attribute `[MaxIters()]` can be used to specify a maximum number of iterations. This will be used by the compiler to allocate space to store intermediate data. If the actual number of iterations exceeds the provided maximum, the behavior is undefined. You can always mark a loop with the `[ForceUnroll]` attribute to instruct the Slang compiler to unroll the loop before generating derivative propagation functions. Unrolled loops will be treated the same way as ordinary code and are not subject to any additional restrictions. +- Double backward derivatives (higher-order differentiation): The compiler does not currently support multiple backward derivative calls such as `bwd_diff(bwd_diff(fn))`. The vast majority of higher-order derivative applications can be achieved more efficiently via multiple forward-derivative calls or a single layer of `bwd_diff` on functions that use one or more `fwd_diff` passes. -## Restrictions of Automatic Differentiation +The above restrictions do not apply if a user-defined derivative or backward propagation function is provided.
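The loop restriction can be illustrated with a minimal sketch. The function name `accumulate` and the bound of 8 are hypothetical and chosen only for illustration; the example assumes the `[MaxIters(n)]` attribute takes an integer literal and is placed directly on the loop.

```csharp
// Hypothetical example: `accumulate` and its bound of 8 are illustrative only.
[Differentiable]
float accumulate(float x, int n)
{
    float acc = 0.0;

    // `n` is a run-time value, so the trip count cannot be inferred statically.
    // [MaxIters(8)] tells the compiler to reserve storage for at most 8 iterations;
    // running more iterations than this is undefined behavior.
    [MaxIters(8)]
    for (int i = 0; i < n; i++)
    {
        acc += x * x;
    }
    return acc;
}

void example()
{
    // Seed the output derivative with 1.0; the derivative with regard to `x`
    // is written to dpx.d, while `n` is passed through unchanged.
    var dpx = diffPair(2.0, 0.0);
    bwd_diff(accumulate)(dpx, 3, 1.0);
}
```

Since the function computes `n * x * x`, the propagated derivative should be `2 * n * x`, provided `n` stays within the declared bound.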
-The compiler can generate forward derivative and backward propagation implementations for most uses of array and struct types, including arbitrary read and write access at dynamic array indices, and supports uses of all types of control flows, mutable parameters, generics and interfaces. This covers the set of operations that is sufficient for a lot of functions. However, the user needs to be aware of the following restrictions when using automatic differentiation: +## Reference -- All operations to global resources, global variables and shader parameters, including texture reads or atomic writes, are treating as a non-differentiable operation. -- If a differentiable function contains calls that cause side-effects such as updates to global memory, there will not be a guarantee on how many times the side-effect will occur during the resulting derivative function or back-propagation function. -- Loops: Loops must use the attribute `[MaxIters()]` to specify a maximum number of iterations. This will be used by compiler to allocate space to store intermediate data. If the actual number of iterations exceeds the provided maximum, the behavior is undefined. You can always mark a loop with the `[ForceUnroll]` attribute to instruct the Slang compiler to unroll the loop before generating derivative propagation functions. Unrolled loops will be treated the same way as ordinary code and are not subject to any additional restrictions. +This section contains some additional information for operators that are not currently included in the [standard library reference](https://shader-slang.org/stdlib-reference/) -The above restrictions do not apply if a user-defined derivative or backward propagation function is provided. +### `fwd_diff(f : slang_function) -> slang_function` +The `fwd_diff` operator can be used on a differentiable function to obtain the forward derivative propagation function. 
+ +A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters. +Given an original function, the signature of its forward propagation function is determined using the following rules: +- If the return type `R` implements `IDifferentiable` the forward propagation function will return a corresponding `DifferentialPair<R>` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`. +- If a parameter has type `T` that implements `IDifferentiable`, it will be translated into a `DifferentialPair<T>` parameter in the derivative function, where the differential component of the `DifferentialPair<T>` holds the initial derivative of the parameter with regard to some upstream parameter. +- If a parameter has type `T` that implements `IDifferentiablePtrType`, it will be translated into a `DifferentialPtrPair<T>` parameter where the differential component carries the differential of the pointer-like object. +- All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function. +- Differentiable methods cannot have a type implementing `IDifferentiablePtrType` as an `out` or `inout` parameter, or a return type. Types implementing `IDifferentiablePtrType` can only be used for input parameters to a differentiable method. Marking such a method as `[Differentiable]` will result in a compile-time diagnostic error.
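As a concrete sketch of the first two rules and the direction rule, consider a small function with one differentiable and one non-differentiable parameter. The name `scaledSquare` is hypothetical and used only for illustration.

```csharp
// Hypothetical function used only for illustration.
[Differentiable]
float scaledSquare(float x, int k)
{
    return float(k) * x * x;
}

void example()
{
    // `x` is differentiable, so it is passed as a pair: primal 3.0, seed derivative 1.0.
    // `k` is an int (non-differentiable), so it is passed unchanged.
    DifferentialPair<float> r = fwd_diff(scaledSquare)(diffPair(3.0, 1.0), 2);

    float value = r.p; // primal result: k * x * x
    float deriv = r.d; // derivative with regard to x: 2 * k * x, scaled by the seed
}
```

Seeding `x` with a derivative of `1.0` makes `r.d` the plain partial derivative of the result with regard to `x`; any other seed scales it linearly.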
+ +For example, given original function: +```csharp +[Differentiable] +R original(T0 p0, inout T1 p1, T2 p2, T3 p3); +``` +Where `R`, `T0`, `T1 : IDifferentiable`, `T2` is non-differentiable, and `T3 : IDifferentiablePtrType`, the forward derivative function will have the following signature: +```csharp +DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2, DifferentialPtrPair<T3> p3); +``` + +This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged. + +### `bwd_diff(f : slang_function) -> slang_function` + +A backward derivative propagation function propagates the derivative of the function output to all the input parameters simultaneously. + +Given an original function `f`, the general rule for determining the signature of its backward propagation function is that a differentiable output `o` becomes an input parameter holding the partial derivative of a downstream output with regard to the differentiable output, i.e. $\partial y/\partial o$; an input differentiable parameter `i` in the original function will become an output in the backward propagation function, holding the propagated partial derivative $\partial y/\partial i$; and any non-differentiable outputs are dropped from the backward propagation function. This means that the backward propagation function never returns any values computed in the original function. + +More specifically, the signature of its backward propagation function is determined using the following rules: +- A backward propagation function always returns `void`.
+- A differentiable `in` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair by the time the backward propagation function returns. The initial derivative value of the pair is ignored as input. +- A differentiable `out` parameter of type `T : IDifferentiable` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the value of that output parameter. +- A differentiable `inout` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument, is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated. +- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function. +- A non-differentiable return value of type `NDR` will be dropped. +- A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function. +- A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function. +- A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter.
+- Types implementing `IDifferentiablePtrType` work the same way as in the forward-mode case. They can only be used with `in` parameters, and are converted into `DifferentialPtrPair<T>` types. Their directions are **not** affected. + +For example, consider the following original function: +```csharp +struct T : IDifferentiable {...} +struct R : IDifferentiable {...} +struct P : IDifferentiablePtrType {...} +struct ND {} // Non-differentiable + +[Differentiable] +R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6); +``` +The signature of its backward propagation function is: +```csharp +void back_prop( + inout DifferentialPair<T> p0, + T.Differential p1, + inout DifferentialPair<T> p2, + ND p3, + ND p5, + DifferentialPtrPair<P> p6, + R.Differential dResult); +``` +Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`. + +### Built-in Differentiable Functions + +The following built-in functions are differentiable and both their forward and backward derivative functions are already defined in the standard library's core module: + +- Arithmetic functions: `abs`, `max`, `min`, `sqrt`, `rcp`, `rsqrt`, `fma`, `mad`, `fmod`, `frac`, `radians`, `degrees` +- Interpolation and clamping functions: `lerp`, `smoothstep`, `clamp`, `saturate` +- Trigonometric functions: `sin`, `cos`, `sincos`, `tan`, `asin`, `acos`, `atan`, `atan2` +- Hyperbolic functions: `sinh`, `cosh`, `tanh` +- Exponential and logarithmic functions: `exp`, `exp2`, `pow`, `log`, `log2`, `log10` +- Vector functions: `dot`, `cross`, `length`, `distance`, `normalize`, `reflect`, `refract` +- Matrix transforms: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)` +- Matrix operations: `transpose`, `determinant` +- Legacy blending and lighting intrinsics: `dst`, `lit` \ No newline at end of file diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html index 9f9085cdd..713c5d8e6 100644 --- a/docs/user-guide/toc.html +++ b/docs/user-guide/toc.html @@ -92,17 +92,16 @@

  • Automatic Differentiation
      -
    • Using Automatic Differentiation in Slang
    • -
    • Mathematic Concepts and Terminologies
    • -
    • Differentiable Value Types
    • -
    • Forward Derivative Propagation Function
    • -
    • Backward Derivative Propagation Function
    • -
    • Builtin Differentiable Functions
    • +
    • Auto-diff operations `fwd_diff` and `bwd_diff`
    • +
    • Differentiable Type System
    • +
    • User-Defined Derivative Functions
    • +
    • Using Auto-diff with Generics
    • +
    • Using Auto-diff with Interface Requirements and Interface Types
    • Primal Substitute Functions
    • Working with Mixed Differentiable and Non-Differentiable Code
    • -
    • Higher Order Differentiation
    • -
    • Interactions with Generics and Interfaces
    • -
    • Restrictions of Automatic Differentiation
    • +
    • Higher-Order Differentiation
    • +
    • Restrictions and Known Issues
    • +
    • Reference
  • Compiling Code with Slang diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index d4cce037d..36e9d6885 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -393,15 +393,47 @@ attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [__NoSideEffect] : NoSideEffectAttribute; -/// Marks a function for forward-mode differentiation. -/// i.e. the compiler will automatically generate a new function -/// that computes the jacobian-vector product of the original. +/// Marks a function as being differentiable in forward-mode. +/// +/// See the user guide [section on forward-mode differentiation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more +/// +/// If used on a function that has a definition (i.e. a function body),Slang will use +/// automatic-differentiation to generate a forward-mode derivative of this function, +/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]` +/// +/// If used on an interface requirement, the signature of the requirement is modified to +/// include forward-differentiability. Any satisfying method must also be forward-differentiable, +/// or provide an appropriate derivative implementation. +/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html##using-auto-diff-with-interface-requirements-and-interface-types) for more +/// __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; -/// Marks a function for backward-mode differentiation. +/// Marks a function as being differentiable for backward-mode auto-diff. +/// Note that in the current implementation, this implies that the method +/// is also forward differentiable. 
+/// +/// This attribute is equivalent to using `[Differentiable]` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute; + +/// Marks a function as being differentiable for both +/// forward and backward mode auto-diff. +/// +/// This attribute is equivalent to using `[BackwardDifferentiable]` +/// +/// See the user guide [section on auto-diff invocation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more. +/// +/// If used on a function that has a definition (i.e. a function body), Slang will use +/// automatic-differentiation to generate the derivative implementations for this function, +/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]` and/or `[BackwardDerivative(fn)]` +/// +/// If used on an interface requirement, the signature of the requirement is modified to +/// include differentiability. Any satisfying method must also be differentiable, +/// or provide appropriate derivative implementations. +/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html#using-auto-diff-with-interface-requirements-and-interface-types) for more +/// __attributeTarget(FunctionDeclBase) attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; @@ -418,18 +450,158 @@ void __requireGLSLExtension(constexpr String preludeText); __intrinsic_op($(kIROp_StaticAssert)) void static_assert(constexpr bool condition, NativeString errorMessage); -/// Represents a type that is differentiable for the purposes of automatic differentiation. +/// Represents a 'value' type that is differentiable for the purposes of automatic differentiation.
+/// +/// See the auto-diff user guide section for an introduction to +/// differentiable value types (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-value-types) +/// +/// #### Builtin Differentiable Value Types +/// The following built-in types are differentiable: +/// - Scalars: `float`, `double` and `half`. +/// - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. +/// - Arrays: `T[n]` is differentiable if `T` is differentiable. +/// - Tuples: `Tuple` is differentiable if `T` is differentiable. +/// +/// The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios): +/// ```csharp +/// interface IDifferentiable +/// { +/// associatedtype Differential : IDifferentiable +/// where Differential.Differential == Differential; +/// +/// static Differential dzero(); +/// +/// static Differential dadd(Differential, Differential); +/// } +/// ``` +/// +/// As defined by the `IDifferentiable` interface, a differentiable type must have a +/// `Differential` associated type that stores the derivative of the value. +/// A further requirement is that the type of the second-order derivative must be the same +/// `Differential` type. In other words, given a type `T`, `T.Differential` can be different +/// from `T`, but `T.Differential.Differential` must be equal to `T.Differential`. +/// +/// In addition, a differentiable type must define the zero value of its derivative (`dzero`), +/// and how to add two derivative values together (`dadd`). These functions are used during reverse-mode +/// auto-diff, to initialize and accumulate derivatives of the given type.
+/// +/// #### Automatic Fulfillment of `IDifferentiable` Requirements +/// Assume the user has defined the following type: +/// +/// ```csharp +/// struct MyRay +/// { +/// float3 origin; +/// float3 dir; +/// int nonDifferentiablePayload; +/// } +/// ``` +/// +/// The type can be made differentiable by adding `IDifferentiable` conformance: +/// ```csharp +/// struct MyRay : IDifferentiable +/// { +/// float3 origin; +/// float3 dir; +/// int nonDifferentiablePayload; +/// } +/// ``` +/// +/// Note that this code does not provide any explicit implementation of the `IDifferentiable` requirements. In this case the compiler will automatically synthesize all the requirements. This should provide the desired behavior most of the time. The procedure for synthesizing the interface implementation is as follows: +/// 1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement. +/// 2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type. +/// 3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type. +/// 4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`. +/// 5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements. 
+/// +/// #### Manual Fulfillment of `IDifferentiable` Requirements +/// In rare cases where more control is desired, the user can manually provide the implementation. +/// To do so, we will first define the `Differential` type for `MyRay`, and use it to fulfill +/// the `Differential` requirement in `MyRay`: +/// ```csharp +/// struct MyRayDifferential +/// { +/// float3 d_origin; +/// float3 d_dir; +/// } +/// +/// struct MyRay : IDifferentiable +/// { +/// // Specify that `MyRay.Differential` is `MyRayDifferential`. +/// typealias Differential = MyRayDifferential; +/// +/// // Specify that the derivative for `origin` will be stored in `MyRayDifferential.d_origin`. +/// [DerivativeMember(MyRayDifferential.d_origin)] +/// float3 origin; +/// +/// // Specify that the derivative for `dir` will be stored in `MyRayDifferential.d_dir`. +/// [DerivativeMember(MyRayDifferential.d_dir)] +/// float3 dir; +/// +/// // This is a non-differentiable field so we don't put any attributes on it. +/// int nonDifferentiablePayload; +/// +/// // Define zero derivative. +/// static MyRayDifferential dzero() +/// { +/// return {float3(0.0), float3(0.0)}; +/// } +/// +/// // Define the add operation of two derivatives. +/// static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) +/// { +/// MyRayDifferential result; +/// result.d_origin = v1.d_origin + v2.d_origin; +/// result.d_dir = v1.d_dir + v2.d_dir; +/// return result; +/// } +/// } +/// ``` /// -/// Implemented by builtin floating-point scalar types (`float`, `half`, `double`) +/// Note that for each struct field that is differentiable, we need to use the `[DerivativeMember]` attribute to associate it with the +/// corresponding field in the `Differential` type, so the compiler knows how to access the derivative for the field.
+/// +/// However, there is still a missing piece in the above code: we also need to make `MyRayDifferential` conform to `IDifferentiable` because the `Differential` of a type must itself conform to `IDifferentiable`. Again we can use automatic fulfillment by simply adding `IDifferentiable` conformance to `MyRayDifferential`: +/// ```csharp +/// struct MyRayDifferential : IDifferentiable +/// { +/// float3 d_origin; +/// float3 d_dir; +/// } +/// ``` +/// In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == float3` as defined in the core module), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`. +/// +/// We can also choose to manually implement the `IDifferentiable` interface for `MyRayDifferential` as in the following code: +/// +/// ```csharp +/// struct MyRayDifferential : IDifferentiable +/// { +/// typealias Differential = MyRayDifferential; +/// +/// [DerivativeMember(MyRayDifferential.d_origin)] +/// float3 d_origin; +/// +/// [DerivativeMember(MyRayDifferential.d_dir)] +/// float3 d_dir; +/// +/// static MyRayDifferential dzero() +/// { +/// return {float3(0.0), float3(0.0)}; +/// } /// -/// vector, matrix and Array automatically conform to -/// `IDifferentiable` if `T` conforms to `IDifferentiable`. +/// static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) +/// { +/// MyRayDifferential result; +/// result.d_origin = v1.d_origin + v2.d_origin; +/// result.d_dir = v1.d_dir + v2.d_dir; +/// return result; +/// } +/// } +/// ``` +/// In this specific case, the automatically generated `IDifferentiable` +/// implementation will be exactly the same as the manually written code listed above.
/// -/// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation -/// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters. -/// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable` -/// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if -/// they are not already defined. +/// __magic_type(DifferentiableType) interface IDifferentiable { @@ -455,8 +627,20 @@ interface IDifferentiable /// @experimental /// -/// Represents a type that supports differentiation operations for pointers, buffers and -/// any other types +/// The `IDifferentiablePtrType` interface requires the following definitions. +/// +/// ```csharp +/// interface IDifferentiablePtrType +/// { +/// associatedtype Differential : IDifferentiablePtrType +/// where Differential.Differential == Differential; +/// } +/// ``` +/// +/// Types that conform to this interface can be used with `DifferentialPtrPair` +/// to pass the derivative components to calls to `fwd_diff(fn)` or `bwd_diff(fn)` +/// +/// See the auto-diff user guide for more details (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-ptr-types) /// /// @remarks Support for this interface is still experimental and subject to change. /// @@ -468,9 +652,26 @@ interface IDifferentiablePtrType }; -/// Pair type that serves to wrap the primal and -/// differential types of a differentiable value type -/// T that conforms to `IDifferentiable`. +/// `DifferentialPair` is a built-in type that carries both the original and derivative value of a term. 
+/// It is defined as follows: +/// ```csharp +/// struct DifferentialPair<T : IDifferentiable> : IDifferentiable +/// { +/// typealias Differential = DifferentialPair<T.Differential>; +/// property T p {get;} +/// property T.Differential d {get;} +/// static Differential dzero(); +/// static Differential dadd(Differential a, Differential b); +/// } +/// ``` +/// +/// Differential pairs can be created via constructor or through the `diffPair()` operation: +/// ```csharp +/// DifferentialPair<float> dpa = DifferentialPair<float>(1.0f, 2.0f); +/// DifferentialPair<float> dpb = diffPair(1.0f, 2.0f); +/// ``` +/// Note that derivative pairs are used to pass derivatives into and out of auto-diff functions. +/// See documentation on `fwd_diff` and `bwd_diff` operators for more information. /// __generic __magic_type(DifferentialPairType) @@ -540,10 +741,36 @@ struct DifferentialPair : IDifferentiable } }; -/// Pair type that serves to wrap the primal and -/// differential types of a differentiable pointer type -/// T that conforms to `IDifferentiablePtrType`. +/// @experimental +/// `DifferentialPtrPair` is a built-in type that carries both the original and differential of a +/// pointer-like object. +/// `T` must conform to `IDifferentiablePtrType` +/// +/// It is defined as follows: +/// ```csharp +/// struct DifferentialPtrPair<T : IDifferentiablePtrType> : IDifferentiablePtrType +/// { +/// typealias Differential = DifferentialPtrPair<T.Differential>; +/// property T p {get;} +/// property T.Differential d {get;} +/// } +/// ``` +/// @remarks +/// Differential ptr pairs can be created via constructor. +/// ```csharp +/// struct DPtrFloat : IDifferentiablePtrType +/// { +/// typealias Differential = DPtrFloat; +/// float* ptr; +/// }; +/// +/// DifferentialPtrPair<DPtrFloat> dpa = +/// DifferentialPtrPair<DPtrFloat>({&outputBuffer[0]}, {&outputBuffer[1]}); +/// ``` +/// Note that derivative ptr pairs are used to pass derivatives into and out of auto-diff functions. +/// See documentation on `fwd_diff` and `bwd_diff` operators for more information.
+/// +/// __generic __magic_type(DifferentialPtrPairType) __intrinsic_type($(kIROp_DifferentialPtrPairType)) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 1200aef42..6f2bd2cd4 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1,31 +1,182 @@ -// Custom Forward Derivative Function reference +/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode +/// derivative implementation. +/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing +/// a derivative implementation. +/// The same behavior holds if `decoratedFn` is used in a differentiable context. +/// +/// +/// @remarks +/// - The signature of `fwdFn` must match the expected signature of `fwd_diff(decoratedFn)`. +/// - See the [reference](https://shader-slang.org/slang/user-guide/autodiff.html#fwd_difff--slang_function---slang_function) for `fwd_diff` for a full list of signature rules. +/// - See the [user guide's section](https://shader-slang.org/slang/user-guide/autodiff.html#user-defined-derivative-functions) on custom derivatives for an introduction to this approach. +/// +/// - This attribute can be used on generic functions, member functions and accessors. +/// - For generic functions, the generic signatures (parameters + constraints) of both functions must match exactly. +/// +/// - The decorated function will be considered forward-differentiable. There is no need for a `[Differentiable]` tag. +/// If the `[Differentiable]` tag is present, +/// and no custom backward derivative is specified with `[BackwardDerivative]`, then +/// Slang will use auto-diff to generate the backward-mode derivative, but will use the provided +/// derivative for forward-mode. +/// +/// Example: +/// ```csharp +/// [ForwardDerivative(foo_fwd)] +/// T foo>(T x, P xarr) { /* ... */ } +/// +/// DifferentialPair foo_fwd>( // Use the same generic signature for a match.
+/// DifferentialPair x, P dp_xarr) { /* ... */ } +/// ``` +/// +/// For member functions, or functions nested inside namespaces, `fwdFn` may need a fully qualified name. +/// +/// Example: +/// ```csharp +/// namespace A +/// { +/// DifferentialPair<float> foo(DifferentialPair<float> x) { /* ... */ } +/// } +/// [ForwardDerivative(A.foo)] // Use namespace and/or parent struct names +/// float bar(float x) { /* ... */ } +/// ``` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +/// `[BackwardDerivative(bwdFn)]` attribute can be used to provide a backward-mode +/// derivative implementation. +/// Invoking `bwd_diff(decoratedFn)` will place a call to `bwdFn` instead of synthesizing +/// a derivative implementation. +/// The same behavior holds if `decoratedFn` is used in a differentiable context. +/// +/// @remarks +/// The signature of `bwdFn` must match the expected signature of `bwd_diff(decoratedFn)`. +/// See the [reference](https://shader-slang.org/slang/user-guide/autodiff.html#bwd_difff--slang_function---slang_function) +/// for `bwd_diff` for a full list of signature rules. +/// +/// See the [user guide's section](https://shader-slang.org/slang/user-guide/autodiff.html#user-defined-derivative-functions) on custom derivatives for an introduction to custom +/// derivatives. +/// +/// This attribute can be used on generic functions, member functions and accessors. +/// For generic functions, the generic signatures (parameters + constraints) of both functions +/// must match exactly. +/// Overloaded functions are also supported. The compiler will attempt to resolve the overload +/// from the expected derivative signature. If it is unable to do so, it will issue a +/// diagnostic error. +/// +/// The decorated function will be considered differentiable. +/// There is no need for a `[Differentiable]` tag. +/// +/// Example: +/// ```csharp +/// [BackwardDerivative(foo_bwd)] +/// T foo>(T x, P xarr) { /* ...
*/ } +/// +/// void foo_bwd>( // Use the same generic signature for a match. +/// inout DifferentialPair x, P dp_xarr, T.Differential dresult) { /* ... */ } +/// ``` +/// +/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified +/// name. +/// +/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; +/// `[PrimalSubstitute(substFn)]` attribute denotes a substitute `substFn` that should be used for +/// differentiation instead of the original function. This serves as a sort of 'reference' implementation +/// where the original function cannot be differentiated (for whatever reason). +/// +/// See the auto-diff user guide for more: https://shader-slang.org/slang/user-guide/autodiff.html#primal-substitute-functions +/// +/// @example The following example shows in more detail on how primal substitute affects derivative computation. +/// ```csharp +/// [PrimalSubstitute(myFuncSubst)] +/// float myFunc(float x) { return x*x; } +/// +/// [Differentiable] +/// float myFuncSubst(float x) { return x*x*x; } +/// +/// [Differentiable] +/// float caller(float x) { return myFunc(x); } +/// +/// let a = caller(4.0); /// a == 16.0 (calling myFunc) +/// let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; /// b == 64.0 (calling myFuncSubst) +/// let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; /// c == 48.0 (calling derivative of myFuncSubst) +/// ``` +/// @remarks +/// `substFn` must have a function definition (i.e. a body). +/// +/// In case that a function has both custom defined derivatives and a differentiable +/// primal substitute, the primal substitute overrides the custom defined derivative +/// on the original function. All calls to the original function will be translated +/// into calls to the primal substitute first, and differentiation step follows after. 
+/// This means that the derivatives of the primal substitute function will be used instead +/// of the derivatives defined on the original function. +/// +/// This attribute can be used on generic functions and member functions. +/// For generic functions, the generic signatures (parameters + constraints) of both functions +/// must match exactly. +/// For member functions, or functions nested inside namespaces, `substFn` should be a fully qualified name. +/// +/// Example: +/// ```csharp +/// namespace A +/// { +/// float foo(float x) { /* ... */ } +/// } +/// [PrimalSubstitute(A.foo)] // Use namespace and/or parent struct names +/// float bar(float x) { /* ... */ } +/// ``` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; +/// `[ForwardDerivativeOf(fn)]` is the back-reference version of `[ForwardDerivative(derivFn)]` +/// +/// When used to decorate a function, the decorated function is considered the forward-derivative +/// implementation of the referenced function `fn`. +/// +/// Apart from this, the semantics of the custom derivative are the same as for +/// `[ForwardDerivative(derivFn)]` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; +/// `[BackwardDerivativeOf(fn)]` is the back-reference version of `[BackwardDerivative(derivFn)]` +/// +/// When used to decorate a function, the decorated function is considered the backward-derivative +/// implementation of the referenced function `fn`. 
+/// +/// Apart from this, the semantics of the custom derivative are the same as for +/// `[BackwardDerivative(derivFn)]` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; +/// `[PrimalSubstituteOf(fn)]` is the back-reference version of `[PrimalSubstitute(substFn)]` +/// +/// When used to decorate a function, that function is considered the substitute for the +/// referenced `fn`. +/// Apart from this difference, the semantics of the substitution are the same as for +/// `[PrimalSubstitute(substFn)]` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; -// Exclude "this" parameter from differentiation. +// Exclude "this" parameter from differentiation. Effectively like putting a 'no_diff' on the +// "this" parameter. +// __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; -// A 'none-type' that acts as a run-time sentinel for zero differentials. +// A 'none-type' that acts as a run-time sentinel for zero differentials. This is primarily +// for internal use. +// [__AutoDiffBuiltin] export struct NullDifferential : IDifferentiable {