| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-30 15:06:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-30 15:06:51 -0800 |
| commit | ae778e3424b39cbeb1f367339f654560de416d30 (patch) | |
| tree | dc8f9cd1d60853eff8e13ab82391f99ae4baa6ab /source | |
| parent | 21893763c18d323af56d30db1bf878abbbdb9060 (diff) | |
[Docs] Auto-diff documentation overhaul (#6202)
* AD: Docs Update
* More documentation
* More documentation
* More docs fixes
* Cleanup documentation
* More docs polish. Add docs for the [Differentiable] attributes
* Fixup code sections
* Fixup
* Address review comments
* regenerate documentation Table of Contents
* Update docs with more playground links
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 269 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 157 |
2 files changed, 402 insertions, 24 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index d4cce037d..36e9d6885 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -393,15 +393,47 @@ attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute;
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [__NoSideEffect] : NoSideEffectAttribute;

-/// Marks a function for forward-mode differentiation.
-/// i.e. the compiler will automatically generate a new function
-/// that computes the jacobian-vector product of the original.
+/// Marks a function as being differentiable in forward-mode.
+///
+/// See the user guide [section on forward-mode differentiation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more.
+///
+/// If used on a function that has a definition (i.e. a function body), Slang will use
+/// automatic-differentiation to generate a forward-mode derivative of this function,
+/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]`.
+///
+/// If used on an interface requirement, the signature of the requirement is modified to
+/// include forward-differentiability. Any satisfying method must also be forward-differentiable,
+/// or provide an appropriate derivative implementation.
+/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html##using-auto-diff-with-interface-requirements-and-interface-types) for more.
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;

-/// Marks a function for backward-mode differentiation.
+/// Marks a function as being differentiable for backward-mode auto-diff.
+/// Note that in the current implementation, this implies that the method
+/// is also forward differentiable.
+///
+/// This attribute is equivalent to using `[Differentiable]`.
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute;
+
+/// Marks a function as being differentiable for both
+/// forward and backward mode auto-diff.
+///
+/// This attribute is equivalent to using `[BackwardDifferentiable]`.
+///
+/// See the user guide [section on auto-diff invocation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more.
+///
+/// If used on a function that has a definition (i.e. a function body), Slang will use
+/// automatic-differentiation to generate the derivative implementations for this function,
+/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]` and/or `[BackwardDerivative(fn)]`.
+///
+/// If used on an interface requirement, the signature of the requirement is modified to
+/// include differentiability. Any satisfying method must also be differentiable,
+/// or provide appropriate derivative implementations.
+/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html##using-auto-diff-with-interface-requirements-and-interface-types) for more.
+///

 __attributeTarget(FunctionDeclBase)
 attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute;
@@ -418,18 +450,158 @@ void __requireGLSLExtension(constexpr String preludeText);
 __intrinsic_op($(kIROp_StaticAssert))
 void static_assert(constexpr bool condition, NativeString errorMessage);

-/// Represents a type that is differentiable for the purposes of automatic differentiation.
+/// Represents a 'value' type that is differentiable for the purposes of automatic differentiation.
+///
+/// See the auto-diff user guide section for an introduction to
+/// differentiable value types (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-value-types)
+///
+/// #### Builtin Differentiable Value Types
+/// The following built-in types are differentiable:
+/// - Scalars: `float`, `double` and `half`.
+/// - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types.
+/// - Arrays: `T[n]` is differentiable if `T` is differentiable.
+/// - Tuples: `Tuple<each T>` is differentiable if `T` is differentiable.
+///
+/// The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios):
+/// ```csharp
+/// interface IDifferentiable
+/// {
+///     associatedtype Differential : IDifferentiable
+///         where Differential.Differential == Differential;
+///
+///     static Differential dzero();
+///
+///     static Differential dadd(Differential, Differential);
+/// }
+/// ```
+///
+/// As defined by the `IDifferentiable` interface, a differentiable type must have a
+/// `Differential` associated type that stores the derivative of the value.
+/// A further requirement is that the type of the second-order derivative must be the same
+/// `Differential` type. In other words, given a type `T`, `T.Differential` can be different
+/// from `T`, but `T.Differential.Differential` must be equal to `T.Differential`.
+///
+/// In addition, a differentiable type must define the `zero` value of its derivative,
+/// and how to add two derivative values together. These functions are used during reverse-mode
+/// auto-diff, to initialize and accumulate derivatives of the given type.
+///
+/// #### Automatic Fulfillment of `IDifferentiable` Requirements
+/// Assume the user has defined the following type:
+///
+/// ```csharp
+/// struct MyRay
+/// {
+///     float3 origin;
+///     float3 dir;
+///     int nonDifferentiablePayload;
+/// }
+/// ```
+///
+/// The type can be made differentiable by adding `IDifferentiable` conformance:
+/// ```csharp
+/// struct MyRay : IDifferentiable
+/// {
+///     float3 origin;
+///     float3 dir;
+///     int nonDifferentiablePayload;
+/// }
+/// ```
+///
+/// Note that this code does not provide any explicit implementation of the `IDifferentiable` requirements. In this case the compiler will automatically synthesize all the requirements. This should provide the desired behavior most of the time. The procedure for synthesizing the interface implementation is as follows:
+/// 1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement.
+/// 2. Each differentiable field will be associated with its corresponding field in the newly synthesized `Differential` type.
+/// 3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type.
+/// 4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`.
+/// 5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that every such synthesized `Differential` type uses itself to meet its own `IDifferentiable` requirements.
+///
+/// #### Manual Fulfillment of `IDifferentiable` Requirements
+/// In rare cases where more control is desired, the user can manually provide the implementation.
+/// To do so, we will first define the `Differential` type for `MyRay`, and use it to fulfill
+/// the `Differential` requirement in `MyRay`:
+/// ```csharp
+/// struct MyRayDifferential
+/// {
+///     float3 d_origin;
+///     float3 d_dir;
+/// }
+///
+/// struct MyRay : IDifferentiable
+/// {
+///     // Specify that `MyRay.Differential` is `MyRayDifferential`.
+///     typealias Differential = MyRayDifferential;
+///
+///     // Specify that the derivative for `origin` will be stored in `MyRayDifferential.d_origin`.
+///     [DerivativeMember(MyRayDifferential.d_origin)]
+///     float3 origin;
+///
+///     // Specify that the derivative for `dir` will be stored in `MyRayDifferential.d_dir`.
+///     [DerivativeMember(MyRayDifferential.d_dir)]
+///     float3 dir;
+///
+///     // This is a non-differentiable field so we don't put any attributes on it.
+///     int nonDifferentiablePayload;
+///
+///     // Define zero derivative.
+///     static MyRayDifferential dzero()
+///     {
+///         return {float3(0.0), float3(0.0)};
+///     }
+///
+///     // Define the add operation of two derivatives.
+///     static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2)
+///     {
+///         MyRayDifferential result;
+///         result.d_origin = v1.d_origin + v2.d_origin;
+///         result.d_dir = v1.d_dir + v2.d_dir;
+///         return result;
+///     }
+/// }
+/// ```
 ///
-/// Implemented by builtin floating-point scalar types (`float`, `half`, `double`)
+/// Note that for each struct field that is differentiable, we need to use the `[DerivativeMember]` attribute to associate it with the
+/// corresponding field in the `Differential` type, so the compiler knows how to access the derivative for the field.
+///
+/// However, there is still a missing piece in the above code: we also need to make `MyRayDifferential` conform to `IDifferentiable` because the `Differential` of a type must itself be differentiable. Again we can use automatic fulfillment by simply adding `IDifferentiable` conformance to `MyRayDifferential`:
+/// ```csharp
+/// struct MyRayDifferential : IDifferentiable
+/// {
+///     float3 d_origin;
+///     float3 d_dir;
+/// }
+/// ```
+/// In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == float3` as defined in the core module), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`.
+///
+/// We can also choose to manually implement the `IDifferentiable` interface for `MyRayDifferential` as in the following code:
+///
+/// ```csharp
+/// struct MyRayDifferential : IDifferentiable
+/// {
+///     typealias Differential = MyRayDifferential;
+///
+///     [DerivativeMember(MyRayDifferential.d_origin)]
+///     float3 d_origin;
+///
+///     [DerivativeMember(MyRayDifferential.d_dir)]
+///     float3 d_dir;
+///
+///     static MyRayDifferential dzero()
+///     {
+///         return {float3(0.0), float3(0.0)};
+///     }
 ///
-/// vector<T, N>, matrix<T, N, M> and Array<T, N> automatically conform to
-/// `IDifferentiable` if `T` conforms to `IDifferentiable`.
+///     static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2)
+///     {
+///         MyRayDifferential result;
+///         result.d_origin = v1.d_origin + v2.d_origin;
+///         result.d_dir = v1.d_dir + v2.d_dir;
+///         return result;
+///     }
+/// }
+/// ```
+/// In this specific case, the automatically generated `IDifferentiable`
+/// implementation will be exactly the same as the manually written code listed above.
 ///
-/// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation
-/// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters.
-/// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable`
-/// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if
-/// they are not already defined.
+///
 __magic_type(DifferentiableType)
 interface IDifferentiable
 {
@@ -455,8 +627,20 @@ interface IDifferentiable

 /// @experimental
 ///
-/// Represents a type that supports differentiation operations for pointers, buffers and
-/// any other types
+/// The `IDifferentiablePtrType` interface requires the following definitions.
+///
+/// ```csharp
+/// interface IDifferentiablePtrType
+/// {
+///     associatedtype Differential : IDifferentiablePtrType
+///         where Differential.Differential == Differential;
+/// }
+/// ```
+///
+/// Types that conform to this interface can be used with `DifferentialPtrPair<T>`
+/// to pass the derivative components to calls to `fwd_diff(fn)` or `bwd_diff(fn)`
+///
+/// See the auto-diff user guide for more details (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-ptr-types)
 ///
 /// @remarks Support for this interface is still experimental and subject to change.
 ///
@@ -468,9 +652,26 @@ interface IDifferentiablePtrType
 };


-/// Pair type that serves to wrap the primal and
-/// differential types of a differentiable value type
-/// T that conforms to `IDifferentiable`.
+/// `DifferentialPair<T>` is a built-in type that carries both the original and derivative value of a term.
+/// It is defined as follows:
+/// ```csharp
+/// struct DifferentialPair<T : IDifferentiable> : IDifferentiable
+/// {
+///     typealias Differential = DifferentialPair<T.Differential>;
+///     property T p {get;}
+///     property T.Differential d {get;}
+///     static Differential dzero();
+///     static Differential dadd(Differential a, Differential b);
+/// }
+/// ```
+///
+/// Differential pairs can be created via constructor or through the `diffPair()` operation:
+/// ```csharp
+/// DifferentialPair<float> dpa = DifferentialPair<float>(1.0f, 2.0f);
+/// DifferentialPair<float> dpb = diffPair(1.0f, 2.0f);
+/// ```
+/// Note that derivative pairs are used to pass derivatives into and out of auto-diff functions.
+/// See documentation on `fwd_diff` and `bwd_diff` operators for more information.
 ///
 __generic<T : IDifferentiable>
 __magic_type(DifferentialPairType)
@@ -540,10 +741,36 @@ struct DifferentialPair : IDifferentiable
     }
 };

-/// Pair type that serves to wrap the primal and
-/// differential types of a differentiable pointer type
-/// T that conforms to `IDifferentiablePtrType`.
+/// @experimental
+/// `DifferentialPtrPair<T>` is a built-in type that carries both the original and differential of a
+/// pointer-like object.
+/// `T` must conform to `IDifferentiablePtrType`.
+///
+/// It is defined as follows:
+/// ```csharp
+/// struct DifferentialPtrPair<T : IDifferentiablePtrType> : IDifferentiablePtrType
+/// {
+///     typealias Differential = DifferentialPtrPair<T.Differential>;
+///     property T p {get;}
+///     property T.Differential d {get;}
+/// }
+/// ```
+/// @remarks
+/// Differential ptr pairs can be created via constructor.
+/// ```csharp
+/// struct DPtrFloat : IDifferentiablePtrType
+/// {
+///     typealias Differential = DPtrFloat;
+///     float* ptr;
+/// };
+///
+/// DifferentialPtrPair<DPtrFloat> dpa =
+///     DifferentialPtrPair<DPtrFloat>({&outputBuffer[0]}, {&outputBuffer[1]});
+/// ```
+/// Note that derivative ptr pairs are used to pass derivatives into and out of auto-diff functions.
+/// See documentation on `fwd_diff` and `bwd_diff` operators for more information.
 ///
+///
 __generic<T : IDifferentiablePtrType>
 __magic_type(DifferentialPtrPairType)
 __intrinsic_type($(kIROp_DifferentialPtrPairType))
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 1200aef42..6f2bd2cd4 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1,31 +1,182 @@
-// Custom Forward Derivative Function reference
+/// `[ForwardDerivative(fwdFn)]` attribute can be used to provide a forward-mode
+/// derivative implementation.
+/// Invoking `fwd_diff(decoratedFn)` will place a call to `fwdFn` instead of synthesizing
+/// a derivative implementation.
+/// The same behavior holds if `decoratedFn` is used in a differentiable context.
+///
+///
+/// @remarks
+/// - The signature of `fwdFn` must match the expected signature of `fwd_diff(decoratedFn)`.
+/// - See the [reference](https://shader-slang.org/slang/user-guide/autodiff.html#fwd_difff--slang_function---slang_function) for `fwd_diff` for a full list of signature rules.
+/// - See the [user guide's section](https://shader-slang.org/slang/user-guide/autodiff.html#user-defined-derivative-functions) on custom derivatives for an introduction to this approach.
+///
+/// - This attribute can be used on generic functions, member functions and accessors.
+/// - For generic functions, the generic signatures (parameters + constraints) of both functions must match exactly.
+///
+/// - The decorated function will be considered forward-differentiable. There is no need for a `[Differentiable]` tag.
+/// If the `[Differentiable]` tag is present,
+/// and no custom backward derivative is specified with `[BackwardDerivative]`, then
+/// Slang will use auto-diff to generate the backward-mode derivative, but will use the provided
+/// derivative for forward-mode.
+///
+/// Example:
+/// ```csharp
+/// [ForwardDerivative(foo_fwd)]
+/// T foo<T : IFloat, P : IArray<T>>(T x, P xarr) { /* ... */ }
+///
+/// DifferentialPair<T> foo_fwd<T : IFloat, P : IArray<T>>( // Use the same generic signature for a match.
+///     DifferentialPair<T> x, P dp_xarr) { /* ... */ }
+/// ```
+///
+/// For member functions, or functions nested inside namespaces, `fwdFn` may need a fully qualified name.
+///
+/// Example:
+/// ```csharp
+/// namespace A
+/// {
+///     DifferentialPair<float> foo(DifferentialPair<float> x) { /* ... */ }
+/// }
+/// [ForwardDerivative(A.foo)] // Use namespace and/or parent struct names
+/// float bar(float x) { /* ... */ }
+/// ```
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;

+/// `[BackwardDerivative(bwdFn)]` attribute can be used to provide a backward-mode
+/// derivative implementation.
+/// Invoking `bwd_diff(decoratedFn)` will place a call to `bwdFn` instead of synthesizing
+/// a derivative implementation.
+/// The same behavior holds if `decoratedFn` is used in a differentiable context.
+///
+/// @remarks
+/// The signature of `bwdFn` must match the expected signature of `bwd_diff(decoratedFn)`.
+/// See the [reference](https://shader-slang.org/slang/user-guide/autodiff.html#bwd_difff--slang_function---slang_function)
+/// for `bwd_diff` for a full list of signature rules.
+///
+/// See the [user guide's section](https://shader-slang.org/slang/user-guide/autodiff.html#user-defined-derivative-functions) on custom derivatives for an introduction to custom
+/// derivatives.
+///
+/// This attribute can be used on generic functions, member functions and accessors.
+/// For generic functions, the generic signatures (parameters + constraints) of both functions
+/// must match exactly.
+/// Overloaded functions are also supported. The compiler will attempt to resolve the overload
+/// from the expected derivative signature. If it is unable to do so, it will issue a
+/// diagnostic error.
+///
+/// The decorated function will be considered differentiable.
+/// There is no need for a `[Differentiable]` tag.
+///
+/// Example:
+/// ```csharp
+/// [BackwardDerivative(foo_bwd)]
+/// T foo<T : IFloat, P : IArray<T>>(T x, P xarr) { /* ... */ }
+///
+/// void foo_bwd<T : IFloat, P : IArray<T>>( // Use the same generic signature for a match.
+///     inout DifferentialPair<T> x, P dp_xarr, T.Differential dresult) { /* ... */ }
+/// ```
+///
+/// For member functions, or functions nested inside namespaces, `bwdFn` may need to be a fully qualified
+/// name.
+///
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;

+/// `[PrimalSubstitute(substFn)]` attribute denotes a substitute `substFn` that should be used for
+/// differentiation instead of the original function. This serves as a sort of 'reference' implementation
+/// where the original function cannot be differentiated (for whatever reason).
+///
+/// See the auto-diff user guide for more: https://shader-slang.org/slang/user-guide/autodiff.html#primal-substitute-functions
+///
+/// @example The following example shows in more detail how a primal substitute affects derivative computation.
+/// ```csharp
+/// [PrimalSubstitute(myFuncSubst)]
+/// float myFunc(float x) { return x*x; }
+///
+/// [Differentiable]
+/// float myFuncSubst(float x) { return x*x*x; }
+///
+/// [Differentiable]
+/// float caller(float x) { return myFunc(x); }
+///
+/// let a = caller(4.0); /// a == 16.0 (calling myFunc)
+/// let b = fwd_diff(caller)(diffPair(4.0, 1.0)).p; /// b == 64.0 (calling myFuncSubst)
+/// let c = fwd_diff(caller)(diffPair(4.0, 1.0)).d; /// c == 48.0 (calling derivative of myFuncSubst)
+/// ```
+/// @remarks
+/// `substFn` must have a function definition (i.e. a body).
+///
+/// If a function has both custom-defined derivatives and a differentiable
+/// primal substitute, the primal substitute overrides the custom-defined derivative
+/// on the original function. All calls to the original function will be translated
+/// into calls to the primal substitute first, and the differentiation step follows after.
+/// This means that the derivatives of the primal substitute function will be used instead
+/// of the derivatives defined on the original function.
+///
+/// This attribute can be used on generic functions and member functions.
+/// For generic functions, the generic signatures (parameters + constraints) of both functions
+/// must match exactly.
+/// For member functions, or functions nested inside namespaces, `substFn` should be a fully qualified name.
+///
+/// Example:
+/// ```csharp
+/// namespace A
+/// {
+///     float foo(float x) { /* ... */ }
+/// }
+/// [PrimalSubstitute(A.foo)] // Use namespace and/or parent struct names
+/// float bar(float x) { /* ... */ }
+/// ```
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;

+/// `[ForwardDerivativeOf(fn)]` is the back-reference version of `[ForwardDerivative(derivFn)]`
+///
+/// When used to decorate a function, the decorated function is considered the forward-derivative
+/// implementation of the referenced function `fn`.
+///
+/// Apart from this, the semantics of the custom derivative are the same as for
+/// `[ForwardDerivative(derivFn)]`
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;

+/// `[BackwardDerivativeOf(fn)]` is the back-reference version of `[BackwardDerivative(derivFn)]`
+///
+/// When used to decorate a function, the decorated function is considered the backward-derivative
+/// implementation of the referenced function `fn`.
+///
+/// Apart from this, the semantics of the custom derivative are the same as for
+/// `[BackwardDerivative(derivFn)]`
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute;

+/// `[PrimalSubstituteOf(fn)]` is the back-reference version of `[PrimalSubstitute(substFn)]`
+///
+/// When used to decorate a function, that function is considered the substitute for the
+/// referenced `fn`.
+/// Apart from this difference, the semantics of the substitution are the same as for
+/// `[PrimalSubstitute(substFn)]`
+///
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute;

 __attributeTarget(DeclBase)
 attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;

-// Exclude "this" parameter from differentiation.
+// Exclude "this" parameter from differentiation. Effectively like putting a 'no_diff' on the
+// "this" parameter.
+//
 __attributeTarget(FunctionDeclBase)
 attribute_syntax [NoDiffThis] : NoDiffThisAttribute;

-// A 'none-type' that acts as a run-time sentinel for zero differentials.
+// A 'none-type' that acts as a run-time sentinel for zero differentials. This is primarily
+// for internal use.
+//
 [__AutoDiffBuiltin]
 export struct NullDifferential : IDifferentiable
 {
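
The values quoted in the `[PrimalSubstitute]` example in the diff above can be checked by hand. This is only a worked sketch, not part of the commit. It assumes that `fwd_diff` propagates the seed derivative `1.0` by the ordinary chain rule, and the symbols $f$ and $g$ are introduced here to stand for `myFunc` and `myFuncSubst`:

$$
\begin{aligned}
f(x) &= x^2, & f(4) &= 16 && \text{(value of \texttt{a}: plain call, uses \texttt{myFunc})} \\
g(x) &= x^3, & g(4) &= 64 && \text{(value of \texttt{b}: primal under \texttt{fwd\_diff}, uses \texttt{myFuncSubst})} \\
g'(x) &= 3x^2, & g'(4) \cdot 1 &= 48 && \text{(value of \texttt{c}: derivative of \texttt{myFuncSubst}, seed } 1\text{)}
\end{aligned}
$$

For comparison, differentiating the original function would give $f'(4) \cdot 1 = 2 \cdot 4 = 8$, so the documented `c == 48.0` is consistent with the substitute, not the original, being differentiated.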
