summaryrefslogtreecommitdiff
path: root/source/slang/core.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-26 17:15:21 -0400
committerGitHub <noreply@github.com>2023-07-26 17:15:21 -0400
commitba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch)
tree2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang/core.meta.slang
parentb8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff)
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces
Diffstat (limited to 'source/slang/core.meta.slang')
-rw-r--r--source/slang/core.meta.slang44
1 files changed, 27 insertions, 17 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 77b9405ba..ae70a83f4 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {}
interface __BuiltinIntegerType : __BuiltinArithmeticType
{}
+/// A type that can represent non-integers
+[sealed]
+[builtin]
+interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
+
+
__attributeTarget(AggTypeDecl)
attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute;
@@ -144,7 +150,7 @@ interface IDifferentiable
// Note: the compiler implementation requires the `Differential` associated type to be defined
// before anything else.
- __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) )
+ __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) )
associatedtype Differential : IDifferentiable;
__builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) )
@@ -154,7 +160,8 @@ interface IDifferentiable
static Differential dadd(Differential, Differential);
__builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) )
- static Differential dmul(This, Differential);
+ __generic<T : __BuiltinRealType>
+ static Differential dmul(T, Differential);
};
@@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable
T.Differential.dadd(a.d, b.d));
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
return Differential(
- T.dmul(a.p, b.p),
- T.Differential.dmul(a.d, b.d));
+ T.dmul<U>(a, b.p),
+ T.Differential.dmul<U>(a, b.d));
}
};
-/// A type that can represent non-integers
-[sealed]
-[builtin]
-interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
/// A type that uses a floating-point representation
[sealed]
@@ -339,6 +343,9 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b
__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse);
__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+// Allow real-number types to be cast into each other
+__intrinsic_op($(kIROp_FloatCast))
+ T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
${{{{
// We are going to use code generation to produce the
@@ -483,12 +490,13 @@ ${{{{
{
return a + b;
}
-
+
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<Differential, U>(a) * b;
}
${{{{
break;
@@ -1190,11 +1198,12 @@ extension vector<T, N> : IDifferentiable
return a + b;
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<T, U>(a) * b;
}
}
@@ -1216,12 +1225,13 @@ extension matrix<T, R, C> : IDifferentiable
{
return a + b;
}
-
+
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<T, U>(a) * b;
}
}