diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/core.meta.slang | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/core.meta.slang')
| -rw-r--r-- | source/slang/core.meta.slang | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index afcff8e65..476279ab8 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -285,6 +285,13 @@ interface IDifferentiable static Differential dmul(T, Differential); }; +__magic_type(DifferentiablePtrType) +interface IDifferentiablePtrType +{ + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) ) + associatedtype Differential : IDifferentiablePtrType; +}; + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable } }; +__generic<T : IDifferentiablePtrType> +__magic_type(DifferentialPtrPairType) +__intrinsic_type($(kIROp_DifferentialPtrPairType)) +struct DifferentialPtrPair : IDifferentiablePtrType +{ + typedef DifferentialPtrPair<T.Differential> Differential; + typedef T.Differential DifferentialElementType; + + __intrinsic_op($(kIROp_MakeDifferentialPtrPair)) + __init(T _primal, T.Differential _differential); + + property p : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential)) + get; + } +}; + /// A type that uses a floating-point representation [sealed] |
