diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-26 17:15:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-26 17:15:21 -0400 |
| commit | ba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch) | |
| tree | 2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang/core.meta.slang | |
| parent | b8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff) | |
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)`
- Add AST synthesis support for generic containers
- Refactor relevant tests
* Merge dmul synthesis with dadd and dzero, and disambiguate using an enum
* Fix trailing spaces
Diffstat (limited to 'source/slang/core.meta.slang')
| -rw-r--r-- | source/slang/core.meta.slang | 44 |
1 files changed, 27 insertions, 17 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 77b9405ba..ae70a83f4 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +/// A type that can represent non-integers +[sealed] +[builtin] +interface __BuiltinRealType : __BuiltinSignedArithmeticType {} + + __attributeTarget(AggTypeDecl) attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; @@ -144,7 +150,7 @@ interface IDifferentiable // Note: the compiler implementation requires the `Differential` associated type to be defined // before anything else. - __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) ) + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) ) associatedtype Differential : IDifferentiable; __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) ) @@ -154,7 +160,8 @@ interface IDifferentiable static Differential dadd(Differential, Differential); __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) ) - static Differential dmul(This, Differential); + __generic<T : __BuiltinRealType> + static Differential dmul(T, Differential); }; @@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable T.Differential.dadd(a.d, b.d)); } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { return Differential( - T.dmul(a.p, b.p), - T.Differential.dmul(a.d, b.d)); + T.dmul<U>(a, b.p), + T.Differential.dmul<U>(a, b.d)); } }; -/// A type that can represent non-integers -[sealed] -[builtin] -interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] @@ -339,6 +343,9 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b __generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); __generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse); +// Allow real-number types to be cast into each other +__intrinsic_op($(kIROp_FloatCast)) + T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val); ${{{{ // We are going to use code generation to produce the @@ -483,12 +490,13 @@ ${{{{ { return a + b; } - + + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<Differential, U>(a) * b; } ${{{{ break; @@ -1190,11 +1198,12 @@ extension vector<T, N> : IDifferentiable return a + b; } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<T, U>(a) * b; } } @@ -1216,12 +1225,13 @@ extension matrix<T, R, C> : IDifferentiable { return a + b; } - + + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<T, U>(a) * b; } } |
