From ba89fc84267bfd09f1c8abf10a5b85d09bbc79de Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:15:21 -0400 Subject: Refactor `dmul(This, Differential)` to `dmul(T, Differential)` (#3029) * Refactor `dmul(This, Differential)` to `dmul(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces --- source/slang/core.meta.slang | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) (limited to 'source/slang/core.meta.slang') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 77b9405ba..ae70a83f4 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +/// A type that can represent non-integers +[sealed] +[builtin] +interface __BuiltinRealType : __BuiltinSignedArithmeticType {} + + __attributeTarget(AggTypeDecl) attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; @@ -144,7 +150,7 @@ interface IDifferentiable // Note: the compiler implementation requires the `Differential` associated type to be defined // before anything else. - __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) ) + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) ) associatedtype Differential : IDifferentiable; __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) ) @@ -154,7 +160,8 @@ interface IDifferentiable static Differential dadd(Differential, Differential); __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) ) - static Differential dmul(This, Differential); + __generic + static Differential dmul(T, Differential); }; @@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable T.Differential.dadd(a.d, b.d)); } + __generic [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { return Differential( - T.dmul(a.p, b.p), - T.Differential.dmul(a.d, b.d)); + T.dmul(a, b.p), + T.Differential.dmul(a, b.d)); } }; -/// A type that can represent non-integers -[sealed] -[builtin] -interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] @@ -339,6 +343,9 @@ __generic __intrinsic_op(select) vector operator?:(vector __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); __generic __intrinsic_op(select) vector select(vector condition, vector ifTrue, vector ifFalse); +// Allow real-number types to be cast into each other +__intrinsic_op($(kIROp_FloatCast)) + T __realCast(U val); ${{{{ // We are going to use code generation to produce the @@ -483,12 +490,13 @@ ${{{{ { return a + b; } - + + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } ${{{{ break; @@ -1190,11 +1198,12 @@ extension vector : IDifferentiable return a + b; } + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } } @@ -1216,12 +1225,13 @@ extension matrix : IDifferentiable { return a + b; } - + + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } } -- cgit v1.2.3