summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-12-01 17:57:53 -0500
committerGitHub <noreply@github.com>2024-12-01 14:57:53 -0800
commit0d92d72981b7883f7aab509a640b4f52dc55a819 (patch)
tree8d3db187247ef430667345bfeaa7314faa42cb4e
parent136c2e22b80d3ebf500d09d5ce6f4fa47dcac8a0 (diff)
[Docs] Minor fixes to auto-diff documentation (#5621)
* Minor fixes to AD documentation * Add a note warning about experimental behavior * Update vulkan --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com> Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--docs/user-guide/07-autodiff.md3
-rw-r--r--source/slang/core.meta.slang36
2 files changed, 28 insertions, 11 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md
index 0664d2499..2a766e1c0 100644
--- a/docs/user-guide/07-autodiff.md
+++ b/docs/user-guide/07-autodiff.md
@@ -167,6 +167,9 @@ interface IDifferentiablePtrType
}
```
+> #### Note ####
+> Support for `IDifferentiablePtrType` is still experimental.
+
Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error.
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 817e3c39d..a9d53162c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -416,13 +416,18 @@ void __requireGLSLExtension(constexpr String preludeText);
__intrinsic_op($(kIROp_StaticAssert))
void static_assert(constexpr bool condition, NativeString errorMessage);
-/// Interface to denote types as differentiable.
-/// Allows for user-specified differential types as
-/// well as automatic generation, for when the associated type
-/// hasn't been declared explicitly.
-/// Note that the requirements must currently be defined in this exact order
-/// since the auto-diff pass relies on the order to grab the struct keys.
+/// Represents a type that is differentiable for the purposes of automatic differentiation.
///
+/// Implemented by builtin floating-point scalar types (`float`, `half`, `double`)
+///
+/// vector<T, N>, matrix<T, N, M> and Array<T, N> automatically conform to
+/// `IDifferentiable` if `T` conforms to `IDifferentiable`.
+///
+/// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation
+/// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters.
+/// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable`
+/// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if
+/// they are not already defined.
__magic_type(DifferentiableType)
interface IDifferentiable
{
@@ -446,9 +451,13 @@ interface IDifferentiable
static Differential dmul(T, Differential);
};
-/// Represents a type that supports differentiation operations for pointer types.
-/// This interface is used to define operations that are specific to pointer types
-/// in the context of automatic differentiation.
+/// @experimental
+///
+/// Represents a type that supports differentiation operations for pointers, buffers and
+/// any other types
+///
+/// @remarks Support for this interface is still experimental and subject to change.
+///
__magic_type(DifferentiablePtrType)
interface IDifferentiablePtrType
{
@@ -458,8 +467,9 @@ interface IDifferentiablePtrType
/// Pair type that serves to wrap the primal and
-/// differential types of an arbitrary type T.
-
+/// differential types of a differentiable value type
+/// T that conforms to `IDifferentiable`.
+///
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
@@ -528,6 +538,10 @@ struct DifferentialPair : IDifferentiable
}
};
+/// Pair type that serves to wrap the primal and
+/// differential types of a differentiable pointer type
+/// T that conforms to `IDifferentiablePtrType`.
+///
__generic<T : IDifferentiablePtrType>
__magic_type(DifferentialPtrPairType)
__intrinsic_type($(kIROp_DifferentialPtrPairType))