summaryrefslogtreecommitdiff
path: root/source/slang/core.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/core.meta.slang
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/core.meta.slang')
-rw-r--r--source/slang/core.meta.slang37
1 files changed, 37 insertions, 0 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index afcff8e65..476279ab8 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -285,6 +285,13 @@ interface IDifferentiable
static Differential dmul(T, Differential);
};
+__magic_type(DifferentiablePtrType)
+interface IDifferentiablePtrType
+{
+ __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
+ associatedtype Differential : IDifferentiablePtrType;
+};
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
@@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable
}
};
+__generic<T : IDifferentiablePtrType>
+__magic_type(DifferentialPtrPairType)
+__intrinsic_type($(kIROp_DifferentialPtrPairType))
+struct DifferentialPtrPair : IDifferentiablePtrType
+{
+ typedef DifferentialPtrPair<T.Differential> Differential;
+ typedef T.Differential DifferentialElementType;
+
+ __intrinsic_op($(kIROp_MakeDifferentialPtrPair))
+ __init(T _primal, T.Differential _differential);
+
+ property p : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
+ get;
+ }
+
+ property v : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
+ get;
+ }
+
+ property d : T.Differential
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential))
+ get;
+ }
+};
+
/// A type that uses a floating-point representation
[sealed]