summaryrefslogtreecommitdiffstats
path: root/docs
diff options
context:
space:
mode:
Diffstat (limited to 'docs')
-rw-r--r--docs/design/autodiff/decorators.md92
-rw-r--r--docs/design/autodiff/ir-overview.md2
-rw-r--r--docs/design/autodiff/types.md290
3 files changed, 383 insertions, 1 deletions
diff --git a/docs/design/autodiff/decorators.md b/docs/design/autodiff/decorators.md
new file mode 100644
index 000000000..b08da2912
--- /dev/null
+++ b/docs/design/autodiff/decorators.md
@@ -0,0 +1,92 @@
+This document details auto-diff-related decorations that are lowered in to the IR to help annotate methods with relevant information.
+
+## `[Differentiable]`
+The `[Differentiable]` attribute is used to mark functions as being differentiable. The auto-diff process will only touch functions that are marked explicitly as `[Differentiable]`. All other functions are considered non-differentiable and calls to such functions from a differentiable function are simply copied as-is with no transformation.
+
+Further, only `[Differentiable]` methods are checked during the derivative data-flow pass. This decorator is translated into `BackwardDifferentiableAttribute` (which implies both forward and backward differentiability), and then lowered into the IR `OpBackwardDifferentiableDecoration`
+
+**Note:** `[Differentiable]` was previously implemented as two separate decorators `[ForwardDifferentiable]` and `[BackwardDifferentiable]` to denote differentiability with each type of auto-diff transformation. However, these are now **deprecated**. The preferred approach is to use only `[Differentiable]`
+
+`fwd_diff` and `bwd_diff` cannot be directly called on methods that don't have the `[Differentiable]` tag (will result in an error). If non-`[Differentiable]` methods are called from within a `[Differentiable]` method, they must be wrapped in `no_diff()` operation (enforced by the [derivative data-flow analysis pass](./types.md#derivative-data-flow-analysis) )
+
+### `[Differentiable]` for `interface` Requirements
+The `[Differentiable]` attribute can also be used to decorate interface requirements. In this case, the attribute is handled in a slightly different manner, since we do not have access to the concrete implementations.
+
+The process is roughly as follows:
+1. During the semantic checking step, when checking a method that is an interface requirement (in `checkCallableDeclCommon` in `slang-check-decl.cpp`), we check if the method has a `[Differentiable]` attribute
+2. If yes, we construct create a set of new method declarations, one for the forward-mode derivative (`ForwardDerivativeRequirementDecl`) and one for the reverse-mode derivative (`BackwardDerivativeRequirementDecl`), with the appropriate translated function types and insert them into the same interface.
+3. Insert a new member into the original method to reference the new declarations (`DerivativeRequirementReferenceDecl`)
+4. When lowering to IR, the `DerivativeRequirementReferenceDecl` member is converted into a custom derivative reference by adding the `OpBackwardDerivativeDecoration(deriv-fn-req-key)` and `OpForwardDerivativeDecoration(deriv-fn-req-key)` decorations on the primal method's requirement key.
+
+Here is an example of what this would look like:
+
+```C
+interface IFoo
+{
+ [Differentiable]
+ float bar(float);
+};
+
+// After checking & lowering
+interface IFoo_after_checking_and_lowering
+{
+ [BackwardDerivative(bar_bwd)]
+ [ForwardDerivative(bar_fwd)]
+ float bar(float);
+
+ void bar_bwd(inout DifferentialPair<float>, float);
+
+ DifferentialPair<float> bar_fwd(DifferentialPair<float>);
+};
+```
+
+**Note:** All conforming types must _also_ declare their corresponding implementations as differentiable so that their derivative implementations are synthesized to match the interface signature. In this sense, the `[Differentiable]` attribute is part of the functions signature, so a `[Differentiable]` interface requirement can only be satisfied by a `[Differentiable]` function implementation
+
+### `[TreatAsDifferentiable]`
+In large codebases where some interfaces may have several possible implementations, it may not be reasonable to have to mark all possible implementations with `[Differentiable]`, especially if certain implementations use hacks or workarounds that need additional consideration before they can be marked `[Differentiable]`
+
+In such cases, we provide the `[TreatAsDifferentiable]` decoration (AST node: `TreatAsDifferentiableAttribute`, IR: `OpTreatAsDifferentiableDecoration`), which instructs the auto-diff passes to construct an 'empty' function that returns a 0 (or 0-equivalent) for the derivative values. This allows the signature of a `[TreatAsDifferentiable]` function to match a `[Differentiable]` requirment without actually having to produce a derivative.
+
+## Custom derivative decorators
+In many cases, it is desirable to manually specify the derivative code for a method rather than let the auto-diff pass synthesize it from the method body. This is usually desirable if:
+1. The body of the method is too complex, and there is a simpler, mathematically equivalent way to compute the same value (often the case for intrinsics like `sin(x)`, `arccos(x)`, etc..)
+2. The method involves global/shared memory accesses, and synthesized derivative code may cause race conditions or be very slow due to overuse of synchronization. For this reason Slang assumes global memory accesses are non-differentiable by default, and requires that the user (or stdlib) define separate accessors with different derivative semantics.
+
+The Slang front-end provides two sets of decorators to facilitate this:
+1. To reference a custom derivative function from a primal function: `[ForwardDerivative(fn)]` and `[BackwardDerivative(fn)]` (AST Nodes: `ForwardDerivativeAttribute`/`BackwardDerivativeAttribute`, IR: `OpForwardDervativeDecoration`/`OpBackwardDerivativeDecoration`), and
+2. To reference a primal function from its custom derivative function: `[ForwardDerivativeOf(fn)]` and `[BackwardDerivativeOf(fn)]` (AST Nodes: `ForwardDerivativeAttributeOf`/`BackwardDerivativeAttributeOf`). These attributes are useful to provide custom derivatives for existing methods in a different file without having to edit/change that module. For instance, we use `diff.meta.slang` to provide derivatives for stdlib functions in `hlsl.meta.slang`. When lowering to IR, these references are placed on the target (primal function). That way both sets of decorations are lowered on the primal function.
+
+These decorators also work on generically defined methods, as well as struct methods. Similar to how function calls work, these decorators also work on overloaded methods (and reuse the `ResolveInoke` infrastructure to perform resolution)
+
+### Checking custom derivative signatures
+To ensure that the user-provided derivatives agree with the expected signature, as well as resolve the appropriate method when multiple overloads are available, we check the signature of the custom derivative function against the translated version of the primal function. This currently occurs in `checkDerivativeAttribute()`/`checkDerivativeOfAttribute()`.
+
+The checking process re-uses existing infrastructure from `ResolveInvoke`, by constructing a temporary invoke expr to call the user-provided derivative using a set of 'imaginary' arguments according to the translated type of the primal method. If `ResolveInvoke` is successful, the provided derivative signature is considered to be a match. This approach also automatically allows us to resolve overloaded methods, account for generic types and type coercion.
+
+## `[PrimalSubstitute(fn)]` and `[PrimalSubstituteOf(fn)]`
+In some cases, we face the opposite problem that inspired custom derivatives. That is, we want the compiler to auto-synthesize the derivative from the function body, but there _is_ no function body to translate.
+This frequently occurs with hardware intrinsic operations that are lowered into special op-codes that map to hardware units, such as texture sampling & interpolation operations.
+However, these operations do have reference 'software' implementations which can be used to produce the derivative.
+
+To allow user code to use the fast hardward intrinsics for the primal pass, but use synthesized derivatives for the derivative pass, we provide decorators `[PrimalSubstitute(ref-fn)]` and `[PrimalSubstituteOf(orig-fn)]` (AST Node: `PrimalSubstituteAttribute`/`PrimalSubstituteOfAttribute`, IR: `OpPrimalSubstituteDecoration`), that can be used to provide a reference implementation for the auto-diff pass.
+
+Example:
+```C
+[PrimalSubstitute(sampleTexture_ref)]
+float sampleTexture(TexHandle2D tex, float2 uv)
+{
+ // Hardware intrinsics
+}
+
+float sampleTexture_ref(TexHandle2D tex, float2 uv)
+{
+ // Reference SW implementation.
+}
+
+void sampleTexture_bwd(TexHandle2D tex, inout DifferentialPair<float2> dp_uv, float dOut)
+{
+ // Backward derivate code synthesized using the reference implementation.
+}
+```
+
+The implementation of `[PrimalSubstitute(fn)]` is relatively straightforward. When the transcribers are asked to synthesize a derivative of a function, they check for a `OpPrimalSubstituteDecoration`, and swap the current function out for the substitute function before proceeding with derivative synthesis. \ No newline at end of file
diff --git a/docs/design/autodiff/ir-overview.md b/docs/design/autodiff/ir-overview.md
index 2dd13b376..a6b3ec207 100644
--- a/docs/design/autodiff/ir-overview.md
+++ b/docs/design/autodiff/ir-overview.md
@@ -1,7 +1,7 @@
This documentation is intended for Slang contributors and is written from a compiler engineering point of view. For Slang users, see the user-guide at this link: [https://shader-slang.com/slang/user-guide/autodiff.html](https://shader-slang.com/slang/user-guide/autodiff.html)
# Overview of Automatic Differentiation's IR Passes
-In this document we will detail how Slang's auto-diff passes generate valid forward-mode and reverse-mode derivative functions. Refer to [Basics](./basics.md) for a review of two derivative propagation methods and their mathematical connotations & [Types](./types.md) for a review of how different types transform under differentiation.
+In this document we will detail how Slang's auto-diff passes generate valid forward-mode and reverse-mode derivative functions. Refer to [Basics](./basics.md) for a review of the two derivative propagation methods and their mathematical connotations & [Types](./types.md) for a review of how types are handled under differentiation.
## Auto-Diff Pass Invocation
Note that without an explicit auto-diff instruction (`fwd_diff(fn)` or `bwd_diff(fn)`) from the user present anywhere in the code, none of the auto-diff passes will do anything.
diff --git a/docs/design/autodiff/types.md b/docs/design/autodiff/types.md
new file mode 100644
index 000000000..12b468a6e
--- /dev/null
+++ b/docs/design/autodiff/types.md
@@ -0,0 +1,290 @@
+
+This documentation is intended for Slang contributors and is written from a compiler engineering point of view. For Slang users, see the user-guide at this link: [https://shader-slang.com/slang/user-guide/autodiff.html](https://shader-slang.com/slang/user-guide/autodiff.html)
+
+Before diving into this document, please review the document on [Basics](./basics.md) for the fundamentals of automatic differentiation.
+
+# Components of the Type System
+Here we detail the main components of the type system: the `IDifferentiable` interface to define differentiable types, the `DifferentialPair<T>` type to carry a primal and corresponding differential in a single type.
+We also detail how auto-diff operators are type-checked (the higher-order function checking system), how the `no_diff` decoration can be used to avoid differentiation through attributed types, and the derivative data flow analysis that warns the the user of unintentionally stopping derivatives.
+
+## `interface IDifferentiable`
+Defined in core.meta.slang, `IDifferentiable` forms the basis for denoting differentiable types, both within the stdlib, and otherwise.
+The definition of `IDifferentiable` is designed to encapsulate the following 4 items:
+1. `Differential`: The type of the differential value of the conforming type. This allows custom data-structures to be defined to carry the differential values, which may be optimized for space instead of relying solely on compiler synthesis/
+
+Since the computation of derivatives is inherently linear, we only need access to a few operations. These are:
+
+2. `dadd(Differential, Differential) -> Differential`: Addition of two values of the differential type. It's implementation must be associative and commutative, or the resulting derivative code may be incorrect.
+3. `dzero() -> Differential`: Additive identity (i.e. the zero or empty value) that can be used to initialize variables during gradient aggregation
+4. `dmul<S:__BuiltinRealType>(S, Differential)`: Scalar multiplication of a real number with the differential type. It's implementation must be distributive over differential addition (`dadd`).
+
+Points 2, 3 & 4 are derived from the concept of vector spaces. The derivative values of any Slang function always form a vector space (https://en.wikipedia.org/wiki/Vector_space).
+
+### Derivative member associations
+In certain scenarios, the compiler needs information on how the fields in the original type map to the differential type. Particularly, this is a problem when differentiate the implicit construction of a struct through braces (i.e. `{}`), represented by `kIROp_MakeStruct`. We provide the decorator `[DerivativeMember(DifferentialTypeName.fieldName)]` (ASTNode: DerivativeMemberAttribute, IR: kIROp_DerivativeMemberDecoration) to explicitly mark these associations.
+Example
+```C
+struct MyType : IDifferentiable
+{
+ typealias Differential = MyDiffType;
+ float a;
+
+ [DerivativeMember(MyDiffType.db)]
+ float b;
+
+ /* ... */
+};
+
+struct MyDiffType
+{
+ float db;
+};
+```
+
+### Automatic Synthesis of `IDifferentible` Conformances for Aggregate Types
+It can be tedious to expect users to hand-write the associated `Differential` type, the corresponding mappings and interface methods for every user-defined `struct` type. For aggregate types, these are trivial to construct by analysing which of their components conform to `IDifferentiable`.
+The synthesis proceeds in roughly the following fashion:
+1. `IDifferentiable`'s components are tagged with special decorator `__builtin_requirement(unique_integer_id)` which carries an enum value from `BuiltinRequirementKind`.
+2. When checking that types conform to their interfaces, if a user-provided definition does not satisfy a requirement with a built-in tag, we perform synthesis by dispatching to `trySynthesizeRequirementWitness`.
+3. For _user-defined types_, Differential **types** are synthesized during conformance-checking through `trySynthesizeDifferentialAssociatedTypeRequirementWitness` by checking if each constituent type conforms to `IDifferentiable`, looking up the corresponding `Differential` type, and constructing a new aggregate type from these differential types. Note that since it is possible that a `Differential` type of a constituent member has not yet been synthesized, we have additional logic in the lookup system (`trySynthesizeRequirementWitness`) that synthesizes a temporary empty type with a `ToBeSynthesizedModifier`, so that the fields can be filled in later, when the member type undergoes conformance checking.
+4. For _user-defined types_, Differential methods (`dadd`, `dzero` and `dmul`) are synthesized in `trySynthesizeDifferentialMethodRequirementWitness` by utilizing the `Differential` member and its `[DifferentialMember]` decorations to determine which fields need to be considered and the base type to use for each field. There are two synthesis patterns. The fully-inductive pattern is used for `dadd` and `dzero` which works by calling `dadd` and `dzero` respectively on the individual fields of the `Differential` type under consideration.
+Example:
+```C
+// Synthesized from "struct T {FT1 field1; FT2 field2;}"
+T.Differential dadd(T.Differential a, T.Differential b)
+{
+ return Differential(
+ FT1.dadd(a.field1, b.field1),
+ FT2.dadd(a.field2, b.field2),
+ )
+}
+```
+On the other hand, `dmul` uses the fixed-first arg pattern since the first argument is a common scalar, and proceeds inductively on all the other args.
+Example:
+```C
+// Synthesized from "struct T {FT1 field1; FT2 field2;}"
+T.Differential dmul<S:__BuiltinRealType>(S s, T.Differential a)
+{
+ return Differential(
+ FT1<S>.dmul(s, a.field1),
+ FT2<S>.dmul(s, a.field2),
+ )
+}
+```
+5. During auto-diff, the compiler can sometimes synthesize new aggregate types. The most common case is the intermediate context type (`kIROp_BackwardDerivativeIntermediateContextType`), which is lowered into a standard struct once the auto-diff pass is complete. It is important to synthesize the `IDifferentiable` conformance for such types since they may be further differentiated (through higher-order differentiation). This implementation is contained in `fillDifferentialTypeImplementationForStruct(...)` and is roughly analogous to the AST-side synthesis.
+
+### Differentiable Type Dictionaries
+During auto-diff, the IR passes frequently need to perform lookups to check if an `IRType` is differentiable, and retreive references to the corresponding `IDifferentiable` methods. These lookups also need to work on generic parameters (that are defined inside generic containers), and existential types that are interface-typed parameters.
+
+To accomodate this range of different type systems, Slang uses a type dictionary system that associates a dictionary of relevant types with each function. This works in the following way:
+1. When `CheckTerm()` is called on an expression within a function that is marked differentiable (`[Differentiable]`), we check if the resolved type conforms to `IDifferentiable`. If so, we add this type to the dictionary along with the witness to its differentiability. The dictionary is currently located on `DifferentiableAttribute` that corresponds to the `[Differentiable]` modifier.
+
+2. When lowering to IR, we create a `DifferentiableTypeDictionaryDecoration` which holds the IR versions of all the types in the dictionary as well as a reference to their `IDifferentiable` witness tables.
+
+3. When synthesizing the derivative code, all the transcriber passes use `DifferentiableTypeConformanceContext::setFunc()` to load the type dictionary. `DifferentiableTypeConformanceContext` then provides convenience functions to lookup differentiable types, appropriate `IDifferentiable` methods, and construct appropriate `DifferentialPair<T>`s.
+
+### Looking up Differential Info on _Generic_ types
+Generically defined types are also lowered into the differentiable type dictionary, but rather than having a concrete witness table, the witness table is itself a parameter. When auto-diff passes need to find the differential type or place a call to the IDifferentiable methods, this is turned into a lookup on the witness table parameter (i.e. `Lookup(<InterfaceRequirementKey>, <WitnessTableParameter>)`). Note that these lookups instructions are inserted into the generic parent container rather than the inner most function.
+Example:
+```C
+T myFunc<T:IDifferentiable>(T a)
+{
+ return a * a;
+}
+
+// Reverse-mode differentiated version
+void bwd_myFunc<T:IDifferentiable>(
+ inout DifferentialPair<T> dpa,
+ T.Differential dOut) // T.Differential is Lookup('Differential', T_Witness_Table)
+{
+ T.Differential da = T.dzero(); // T.dzero is Lookup('dzero', T_Witness_Table)
+
+ da = T.dadd(dpa.p * dOut, da); // T.dadd is Lookup('dadd', T_Witness_Table)
+ da = T.dadd(dpa.p * dOut, da);
+
+ dpa = diffPair(dpa.p, da);
+}
+```
+
+### Looking up Differential Info on _Existential_ types
+Existential types are interface-typed values, where there are multiple possible implementations at run-time. The existential type carries information about the concrete type at run-time and is effectively a 'tagged union' of all possible types.
+
+#### Differential type of an Existential
+The differential type of an existential type is tricky to define since our type system's only restriction on the `.Differential` type is that it also conforms to `IDifferentiable`. The differential type of any interface `IInterface : IDifferentiable` is therefore the interface type `IDifferentiable`. This is problematic since Slang generally requires a static `anyValueSize` that must be a strict upper bound on the sizes of all conforming types (since this size is used to allocate space for the union). Since `IDifferentiable` is defined in the stdlib `core.meta.slang` and can be used by the user, it is impossible to define a reliable bound.
+We instead provide a new **any-value-size inference** pass (`slang-ir-any-value-inference.h`/`slang-ir-any-value-inference.cpp`) that assembles a list of types that conform to each interface in the final linked IR and determines a relevant upper bound. This allows us to ignore types that conform to `IDifferentiable` but aren't used in the final IR, and generate a tighter upper bound.
+
+**Future work:**
+This approach, while functional, creates a locality problem since the size of `IDifferentiable` is the max of _all_ types that conform to `IDifferentiable` in visible modules, even though we only care about the subset of types that appear as `T.Differential` for `T : IInterface`. The reason for this problem is that upon performing an associated type lookup, the Slang IR drops all information about the base interface that the lookup starts from and only considers the constraint interface (in this case `Differential : IDifferentiable`).
+There are several ways to resolve this issue, including (i) a static analysis pass that determines the possible set of types at each use location and propagates them to determine a narrower set of types, or (ii) generic (or 'parameterized') interfaces, such as `IDifferentiable<T>` where each version can have a different set of conforming types.
+
+<!--#### IDifferentiable Method lookups on an Existential
+All other method lookups are performed using existential-type lookups on the existential parameter. The idea is that existential-typed parameters come with a witness-table component that can be accessed by invoking `kIROp_ExtractExistentialWitnessTable` on them. This allows us to look up the `dadd`/`dzero` methods on this witness table in the same way as we did for generic types.-->
+
+Example:
+```C
+interface IInterface : IDifferentiable
+{
+ [Differentiable]
+ This foo(float val);
+
+ [Differentiable]
+ float bar();
+};
+
+float myFunc(IInterface obj, float a)
+{
+ IInterface k = obj.foo(a);
+ return k.bar();
+}
+
+// Reverse-mode differentiated version (in pseudo-code corresponding to IR, some of these will get lowered further)
+void bwd_myFunc(
+ inout DifferentialPair<IInterface> dpobj,
+ inout DifferentialPair<float> dpa,
+ float.Differential dOut) // T.Differential is Lookup('Differential', T_Witness_Table)
+{
+ // Primal pass..
+ IInterface obj = dpobj.p;
+ IInterface k = obj.foo(a);
+ // .....
+
+ // Backward pass
+ DifferentialPair<IInterface> dpk = diffPair(k);
+ bwd_bar(dpk, dOut);
+ IDifferentiable dk = dpk.d; // Differential of `IInterface` is `IDifferentiable`
+
+ DifferentialPair<IInterface> dp = diffPair(dpobj.p);
+ bwd_foo(dpobj, dpa, dk);
+}
+
+```
+
+#### Looking up `dadd()` and `dzero()` on Existential Types
+There are two distinct cases for lookup on an existential type. The more common case is the closed-box existential type represented simply by an interface. Every value of this type contains a type identifier & a witness table identifier along with the value itself. The less common case is when the function calls are performed directly on the value after being cast to the concrete type.
+
+**`dzero()` for "closed" Existential type: The `NullDifferential` Type**
+For concrete and even generic types, we can initialize a derivative accumulator variable by calling the appropriate `Type.dzero()` method. This is unfortunately not possible when initializing an existential differential (which is currently of type `IDifferentiable`), since we must also initialize the type-id of this existential to one of the implementations, but we do not know which one yet since that is a run-time value that only becomes known after the first differential value is generated.
+
+To get around this issue, we declare a special type called `NullDifferential` that acts as a "none type" for any `IDifferentiable` existential object.
+
+**`dadd()` for "closed" Existential types: `__existential_dadd`**
+We cannot directly use `dadd()` on two existential differentials of type `IDifferentiable` because we must handle the case where one of them is of type `NullDifferential` and `dadd()` is only defined for differentials of the same type.
+We handle this currently by synthesizing a special method called `__existential_dadd` (`getOrCreateExistentialDAddMethod` in `slang-ir-autodiff.cpp`) that performs a run-time type-id check to see if one of the operand is of type `NullDifferential` and returns the other operand if so. If both are non-null, we dispatch to the appropriate `dadd` for the concrete type.
+
+**`dadd()` and `dzero()` for "open" Existential types**
+If we are dealing with values of the concrete type (i.e. the opened value obtained through `ExtractExistentialValue(ExistentialParam)`). Then we can perform lookups in the same way we do for generic type. All existential parameters come with a witness table. We insert instructions to extract this witness table and perform lookups accordingly. That is, for `dadd()`, we use `Lookup('dadd', ExtractExistentialWitnessTable(ExistentialParam))` and place a call to the result.
+
+## `struct DifferentialPair<T:IDifferentiable>`
+The second major component is `DifferentialPair<T:IDifferentiable>` that represents a pair of a primal value and its corresponding differential value.
+The differential pair is primarily used for passing & receiving derivatives from the synthesized derivative methods, as well as for block parameters on the IR-side.
+Both `fwd_diff(fn)` and `bwd_diff(fn)` act as function-to-function transformations, and so the Slang front-end translates the type of `fn` to its derivative version so the arguments can be type checked.
+
+### Pair type lowering.
+The differential pair type is a special type throughout the AST and IR passes (AST Node: `DifferentialPairType`, IR: `kIROp_DifferentialPairType`) because of its use in front-end semantic checking and when synthesizing the derivative code for the functions. Once the auto-diff passes are complete, the pair types are lowering into simple `struct`s so they can be easily emitted (`DiffPairLoweringPass` in `slang-ir-autodiff-pairs.cpp`).
+We also define additional instructions for pair construction (`kIROp_MakeDifferentialPair`) and extraction (`kIROp_DifferentialPairGetDifferential` & `kIROp_DifferentialPairGetPrimal`) which are lowered into struct construction and field accessors, respectively.
+
+### "User-code" Differential Pairs
+Just as we use special IR codes for differential pairs because they have special handling in the IR passes, sometimes differential pairs should be _treated as_ regular struct types during the auto-diff passes.
+This happens primarily during higher-order differentiation when the user wishes to differentiate the same code multiple times.
+Slang's auto-diff approaches this by rewriting all the relevant differential pairs into 'irrelevant' differential pairs (`kIROp_DifferentialPairUserCode`) and 'irrelevant' accessors (`kIROp_DifferentialPairGetDifferentialUserCode`, `kIROp_DifferentialPairGetPrimalUserCode`) at the end of **each auto-diff iteration** so that the next iteration treats these as regular differentiable types.
+The user-code versions are also lowered into `struct`s in the same way.
+
+## Type Checking of Auto-Diff Calls (and other _higher-order_ functions)
+Since `fwd_diff` and `bwd_diff` are represented as higher order functions that take a function as an input and return the derivative function, the front-end semantic checking needs some notion of higher-order functions to be able to check and lower the calls into appropriate IR.
+
+### Higher-order Invocation Base: `HigherOrderInvokeExpr`
+All higher order transformations derive from `HigherOrderInvokeExpr`. For auto-diff there are two possible expression classes `ForwardDifferentiateExpr` and `BackwardDifferentiateExpr`, both of which derive from this parent expression.
+
+### Higher-order Function Call Checking: `HigherOrderInvokeExprCheckingActions`
+Resolving the concrete method is not a trivial issue in Slang, given its support for overloading, type coercion and more. This becomes more complex with the presence of a function transformation in the chain.
+For example, if we have `fwd_diff(f)(DiffPair<float>(...), DiffPair<double>(...))`, we would need to find the correct match for `f` based on its post-transform argument types.
+
+To facilitate this we use the following workflow:
+1. The `HigherOrderInvokeExprCheckingActions` base class provides a mechanism for different higher-order expressions to implement their type translation (i.e. what is the type of the transformed function).
+2. The checking mechanism passes all detected overloads for `f` through the type translation and assembles a new group out of the results (the new functions are 'temporary')
+3. This new group is used by `ResolveInvoke` when performing overload resolution and type coercion using the user-provided argument list.
+4. The resolved signature (if there is one) is then replaced with the corresponding function reference and wrapped in the appropriate higher-order invoke.
+
+**Example:**
+
+Let's say we have two functions with the same name `f`: (`int -> float`, `double, double -> float`)
+and we want to resolve `fwd_diff(f)(DiffPair<float>(1.0, 0.0), DiffPair<float>(0.0, 1.0))`.
+
+The higher-order checking actions will synthesize the 'temporary' group of translated signatures (`int -> DiffPair<float>`, `DiffPair<double>, DiffPair<double> -> DiffPair<float>`).
+Invoke resolution will then narrow this down to a single match (`DiffPair<double>, DiffPair<double> -> DiffPair<float>`) by automatically casting the `float`s to `double`s. Once the resolution is complete,
+we return `InvokeExpr(ForwardDifferentiateExpr(f : double, double -> float), casted_args)` by wrapping the corresponding function in the corresponding higher-order expr
+
+## Attributed Types (`no_diff` parameters)
+
+Often, it will be necessary to prevent gradients from propagating through certain parameters, for correctness reasons. For example, values representing random samples are often not differentiated since the result may be mathematically incorrect.
+
+Slang provides the `no_diff` operator to mark parameters as non-differentiable, even if they use a type that conforms to `IDifferentiable`
+
+```C
+float myFunc(float a, no_diff float b)
+{
+ return a * b;
+}
+
+// Resulting fwd-mode derivative:
+DiffPair<float> myFunc(DiffPair<float> dpa, float b)
+{
+ return diffPair(dpa.p * b, dpa.d * b);
+}
+```
+
+Slang uses _OpAttributedType_ to denote the IR type of such parameters. For example, the lowered type of `b` in the above example is `OpAttributedType(OpFloat, OpNoDiffAttr)`. In the front-end, this is represented through the `ModifiedType` AST node.
+
+Sometimes, this additional layer can get in the way of things like type equality checks and other mechanisms where the `no_diff` is irrelevant. Thus, we provide the `unwrapAttributedType` helper to remove attributed type layers for such cases.
+
+## Derivative Data-Flow Analysis
+Slang has a derivative data-flow analysis pass that is performed on a per-function basis immediately after lowering to IR and before the linking step (`slang-ir-check-differentiability.h`/`slang-ir-check-differentiability.cpp`).
+
+The job of this pass is to enforce that instructions that are of a differentiable type will propagate a derivatives, unless explicitly dropped by the user through `detach()` or `no_diff`. The reason for this is that Slang requires functions to be decorated with `[Differentiable]` to allow it to propagate derivatives. Otherwise, the function is considered non-differentiable, and effectively produces a 0 derivative. This can lead to frustrating situations where a function may be dropping non-differentiable on purpose. Example:
+```C
+float nonDiffFunc(float x)
+{
+ /* ... */
+}
+
+float differentiableFunc(float x) // Forgot to annotate with [Differentiable]
+{
+ /* ... */
+}
+
+float main(float x)
+{
+ // User doesn't realise that the function that is supposed to be differentiable is not
+ // getting differentiated, because the types here are all 'float'.
+ //
+ return nonDiffFunc(x) * differentiableFunc(x);
+}
+```
+
+The data-flow analysis step enforces that non-differentiable functions used in a differentiable context should get their derivative dropped explicitly. That way, it is clear to the user whether a call is getting differentiated or dropped.
+
+Same example with `no_diff` enforcement:
+```C
+float nonDiffFunc(float x)
+{
+ /* ... */
+}
+
+[Differentiable]
+float differentiableFunc(float x)
+{
+ /* ... */
+}
+
+float main(float x)
+{
+ return no_diff(nonDiffFunc(x)) * differentiableFunc(x);
+}
+```
+
+A `no_diff` can only be used directly on a function call, and turns into a `TreatAsDifferentiableDecoration` that indicates that the function will not produce a derivative.
+
+The derivative data-flow analysis pass works similar to a standard data-flow pass:
+1. We start by assembling a set of instructions that 'produce' derivatives by starting with the parameters of differentiable types (and without an explicit `no_diff`), and propagating them through each instruction in the block. An inst carries a derivative if there one of its operands carries a derivative, and the result type is differentiable.
+2. We then assemble a set of instructions that expect a derivative. These are differentiable operands of differentiable functions (unless they have been marked by `no_diff`). We then reverse-propagate this set by adding in all differentiable operands (and repeating this process).
+3. During this reverse-propagation, if there is any `OpCall` in the 'expect' set that is not also in the 'produce' set, then we have a situation where the gradient hasn't been explicitly dropped, and we create a user diagnostic.