diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-30 15:06:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-30 15:06:51 -0800 |
| commit | ae778e3424b39cbeb1f367339f654560de416d30 (patch) | |
| tree | dc8f9cd1d60853eff8e13ab82391f99ae4baa6ab /source/slang/core.meta.slang | |
| parent | 21893763c18d323af56d30db1bf878abbbdb9060 (diff) | |
[Docs] Auto-diff documentation overhaul (#6202)
* AD: Docs Update
* More documentation
* More documentation
* More docs fixes
* Cleanup documentation
* More docs polish. Add docs for the [Differentiable] attributes
* Fixup code sections
* Fixup
* Address review comments
* regenerate documentation Table of Contents
* Update docs with more playground links
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/core.meta.slang')
| -rw-r--r-- | source/slang/core.meta.slang | 269 |
1 files changed, 248 insertions, 21 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index d4cce037d..36e9d6885 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -393,15 +393,47 @@ attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [__NoSideEffect] : NoSideEffectAttribute; -/// Marks a function for forward-mode differentiation. -/// i.e. the compiler will automatically generate a new function -/// that computes the jacobian-vector product of the original. +/// Marks a function as being differentiable in forward-mode. +/// +/// See the user guide [section on forward-mode differentiation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more +/// +/// If used on a function that has a definition (i.e. a function body),Slang will use +/// automatic-differentiation to generate a forward-mode derivative of this function, +/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]` +/// +/// If used on an interface requirement, the signature of the requirement is modified to +/// include forward-differentiability. Any satisfying method must also be forward-differentiable, +/// or provide an appropriate derivative implementation. +/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html##using-auto-diff-with-interface-requirements-and-interface-types) for more +/// __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; -/// Marks a function for backward-mode differentiation. +/// Marks a function as being differentiable for backward-mode auto-diff. +/// Note that in the current implementation, this implies that the method +/// is also forward differentiable. +/// +/// This attribute is equivalent to using `[Differentiable]` +/// __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute; + +/// Marks a function as being differentiable for both +/// forward and backward mode auto-diff. +/// +/// This attribute is equivalent to using `[Differentiable]` +/// +/// See the user guide [section on auto-diff invocation](https://shader-slang.org/slang/user-guide/autodiff.html#invoking-auto-diff-in-slang) for more. +/// +/// If used on a function that has a definition (i.e. a function body), Slang will use +/// automatic-differentiation to generate the derivative implementations for this function, +/// unless an implementation is provided by the user via `[ForwardDerivative(fn)]` and/or `[BackwardDerivative(fn)]` +/// +/// If used on an interface requirement, the signature of the requirement is modified to +/// include differentiability. Any satisfying method must also be differentiable, +/// or provide appropriate derivative implementations. +/// See the user guide [section on auto-diff for interfaces](https://shader-slang.org/slang/user-guide/autodiff.html##using-auto-diff-with-interface-requirements-and-interface-types) for more +/// __attributeTarget(FunctionDeclBase) attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; @@ -418,18 +450,158 @@ void __requireGLSLExtension(constexpr String preludeText); __intrinsic_op($(kIROp_StaticAssert)) void static_assert(constexpr bool condition, NativeString errorMessage); -/// Represents a type that is differentiable for the purposes of automatic differentiation. +/// Represents a 'value' type that is differentiable for the purposes of automatic differentiation. +/// +/// See the auto-diff user guide section for an introduction to +/// differentiable value types (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-value-types) +/// +/// #### Builtin Differentiable Value Types +/// The following built-in types are differentiable: +/// - Scalars: `float`, `double` and `half`. +/// - Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types. +/// - Arrays: `T[n]` is differentiable if `T` is differentiable. +/// - Tuples: `Tuple<each T>` is differentiable if `T` is differentiable. +/// +/// The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios) +/// ```csharp +/// interface IDifferentiable +/// { +/// associatedtype Differential : IDifferentiable +/// where Differential.Differential == Differential; +/// +/// static Differential dzero(); +/// +/// static Differential dadd(Differential, Differential); +/// } +/// ``` +/// +/// As defined by the `IDifferentiable` interface, a differentiable type must have a +/// `Differential` associated type that stores the derivative of the value. +/// A further requirement is that the type of the second-order derivative must be the same +/// `Differential` type. In another words, given a type `T`, `T.Differential` can be different +/// from `T`, but `T.Differential.Differential` must equal to `T.Differential`. +/// +/// In addition, a differentiable type must define the `zero` value of its derivative, +/// and how to add two derivative values together. These function are used during reverse-mode +/// auto-diff, to initialize and accumulate derivatives of the given type. +/// +/// #### Automatic Fulfillment of `IDifferentiable` Requirements +/// Assume the user has defined the following type: +/// +/// ```csharp +/// struct MyRay +/// { +/// float3 origin; +/// float3 dir; +/// int nonDifferentiablePayload; +/// } +/// ``` +/// +/// The type can be made differentiable by adding `IDifferentiable` conformance: +/// ```csharp +/// struct MyRay : IDifferentiable +/// { +/// float3 origin; +/// float3 dir; +/// int nonDifferentiablePayload; +/// } +/// ``` +/// +/// Note that this code does not provide any explicit implementation of the `IDifferentiable` requirements. In this case the compiler will automatically synthesize all the requirements. This should provide the desired behavior most of the time. The procedure for synthesizing the interface implementation is as follows: +/// 1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement. +/// 2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type. +/// 3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type. +/// 4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`. +/// 5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements. +/// +/// #### Manual fulfilment of `IDifferentiable` requirements +/// In rare cases where more control is desired, the user can manually provide the implementation. +/// To do so, we will first define the `Differential` type for `MyRay`, and use it to fulfill +/// the `Differential` requirement in `MyRay`: +/// ```csharp +/// struct MyRayDifferential +/// { +/// float3 d_origin; +/// float3 d_dir; +/// } +/// +/// struct MyRay : IDifferentiable +/// { +/// // Specify that `MyRay.Differential` is `MyRayDifferential`. +/// typealias Differential = MyRayDifferential; +/// +/// // Specify that the derivative for `origin` will be stored in `MayRayDifferential.d_origin`. +/// [DerivativeMember(MayRayDifferential.d_origin)] +/// float3 origin; +/// +/// // Specify that the derivative for `dir` will be stored in `MayRayDifferential.d_dir`. +/// [DerivativeMember(MayRayDifferential.d_dir)] +/// float3 dir; +/// +/// // This is a non-differentiable field so we don't put any attributes on it. +/// int nonDifferentiablePayload; +/// +/// // Define zero derivative. +/// static MyRayDifferential dzero() +/// { +/// return {float3(0.0), float3(0.0)}; +/// } +/// +/// // Define the add operation of two derivatives. +/// static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) +/// { +/// MyRayDifferential result; +/// result.d_origin = v1.d_origin + v2.d_origin; +/// result.d_dir = v1.d_dir + v2.d_dir; +/// return result; +/// } +/// } +/// ``` /// -/// Implemented by builtin floating-point scalar types (`float`, `half`, `double`) +/// Note that for each struct field that is differentiable, we need to use the `[DerivativeMember]` attribute to associate it with the +/// corresponding field in the `Differential` type, so the compiler knows how to access the derivative for the field. +/// +/// However, there is still a missing piece in the above code: we also need to make `MyRayDifferential` conform to `IDifferentiable` because it is required that the `Differential` of a type must itself be `Differential`. Again we can use automatic fulfillment by simply adding `IDifferentiable` conformance to `MyRayDifferential`: +/// ```csharp +/// struct MyRayDifferential : IDifferentiable +/// { +/// float3 d_origin; +/// float3 d_dir; +/// } +/// ``` +/// In this case, since all fields in `MyRayDifferential` are differentiable, and the `Differential` of each field is the same as the original type of each field (i.e. `float3.Differential == float3` as defined in the core module), the compiler will automatically use the type itself as its own `Differential`, making `MyRayDifferential` suitable for use as `Differential` of `MyRay`. +/// +/// We can also choose to manually implement `IDifferentiable` interface for `MyRayDifferential` as in the following code: +/// +/// ```csharp +/// struct MyRayDifferential : IDifferentiable +/// { +/// typealias Differential = MyRayDifferential; +/// +/// [DerivativeMember(MyRayDifferential.d_origin)] +/// float3 d_origin; +/// +/// [DerivativeMember(MyRayDifferential.d_dir)] +/// float3 d_dir; +/// +/// static MyRayDifferential dzero() +/// { +/// return {float3(0.0), float3(0.0)}; +/// } /// -/// vector<T, N>, matrix<T, N, M> and Array<T, N> automatically conform to -/// `IDifferentiable` if `T` conforms to `IDifferentiable`. +/// static MyRayDifferential dadd(MyRayDifferential v1, MyRayDifferential v2) +/// { +/// MyRayDifferential result; +/// result.d_origin = v1.d_origin + v2.d_origin; +/// result.d_dir = v1.d_dir + v2.d_dir; +/// return result; +/// } +/// } +/// ``` +/// In this specific case, the automatically generated `IDifferentiable` +/// implementation will be exactly the same as the manually written code listed above. /// -/// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation -/// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters. -/// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable` -/// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if -/// they are not already defined. +/// __magic_type(DifferentiableType) interface IDifferentiable { @@ -455,8 +627,20 @@ interface IDifferentiable /// @experimental /// -/// Represents a type that supports differentiation operations for pointers, buffers and -/// any other types +/// The `IDifferentiablePtrType` interface requires the following definitions. +/// +/// ```csharp +/// interface IDifferentiablePtrType +/// { +/// associatedtype Differential : IDifferentiablePtrType +/// where Differential.Differential == Differential; +/// } +/// ``` +/// +/// Types that conform to this interface can be used with `DifferentialPtrPair<T>` +/// to pass the derivative components to calls to `fwd_diff(fn)` or `bwd_diff(fn)` +/// +/// See the auto-diff user guide for more details (https://shader-slang.org/slang/user-guide/autodiff.html#differentiable-ptr-types) /// /// @remarks Support for this interface is still experimental and subject to change. /// @@ -468,9 +652,26 @@ interface IDifferentiablePtrType }; -/// Pair type that serves to wrap the primal and -/// differential types of a differentiable value type -/// T that conforms to `IDifferentiable`. +/// `DifferentialPair<T>` is a built-in type that carries both the original and derivative value of a term. +/// It is defined as follows: +/// ```csharp +/// struct DifferentialPair<T : IDifferentiable> : IDifferentiable +/// { +/// typealias Differential = DifferentialPair<T.Differential>; +/// property T p {get;} +/// property T.Differential d {get;} +/// static Differential dzero(); +/// static Differential dadd(Differential a, Differential b); +/// } +/// ``` +/// +/// Differential pairs can be created via constructor or through the `diffPair()` operation +/// ```csharp +/// DifferentialPair<float> dpa = DifferentialPair<float>(1.0f, 2.0f); +/// DifferentialPair<float> dpa = diffPair(1.0f, 2.0f); +/// ``` +/// Note that derivative pairs are used to pass derivatives into and out of auto-diff functions. +/// See documentation on `fwd_diff` and `bwd_diff` operators for more information. /// __generic<T : IDifferentiable> __magic_type(DifferentialPairType) @@ -540,10 +741,36 @@ struct DifferentialPair : IDifferentiable } }; -/// Pair type that serves to wrap the primal and -/// differential types of a differentiable pointer type -/// T that conforms to `IDifferentiablePtrType`. +/// @experimental +/// `DifferentialPtrPair<T>` is a built-in type that carries both the original and differential of a +/// pointer-like object. +/// `T` must conform to `IDifferentiablePtrType` +/// +/// It is defined as follows: +/// ```csharp +/// struct DifferentialPtrPair<T : IDifferentiablePtrType> : IDifferentiablePtrType +/// { +/// typealias Differential = DifferentialPtrPair<T.Differential>; +/// property T p {get;} +/// property T.Differential d {get;} +/// } +/// ``` +/// @remarks +/// Differential ptr pairs can be created via constructor. +/// ```csharp +/// struct DPtrFloat : IDifferentialPtrType +/// { +/// typealias Differential = DPtrFloat; +/// float* ptr; +/// }; +/// +/// DifferentialPtrPair<DPtrFloat> dpa = +/// DifferentialPtrPair<float>({&outputBuffer[0]}, {&outputBuffer[1]}); +/// ``` +/// Note that derivative ptr pairs are used to pass derivatives into and out of auto-diff functions. +/// See documentation on `fwd_diff` and `bwd_diff` operators for more information. /// +/// __generic<T : IDifferentiablePtrType> __magic_type(DifferentialPtrPairType) __intrinsic_type($(kIROp_DifferentialPtrPairType)) |
