summaryrefslogtreecommitdiff
path: root/source/slang/core.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-08 10:07:57 -0800
committerGitHub <noreply@github.com>2022-11-08 10:07:57 -0800
commitbf67309454032b4f92d0bc9735b608e56b16882f (patch)
treea321fe7db0b49fa67608b935c1389354a020f59c /source/slang/core.meta.slang
parentca882a1ef46a5a8bbff50e3a1a6f973e16358634 (diff)
Make `__BuiltinFloatingPointType` conform to `IDifferentiable`. (#2499)
Diffstat (limited to 'source/slang/core.meta.slang')
-rw-r--r--source/slang/core.meta.slang79
1 files changed, 76 insertions, 3 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 05963bd11..a37124bdc 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -162,7 +162,7 @@ interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
/// A type that uses a floating-point representation
[sealed]
[builtin]
-interface __BuiltinFloatingPointType : __BuiltinRealType
+interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable
{
/// Initialize from a 32-bit floating-point value.
__init(float value);
@@ -369,6 +369,26 @@ ${{{{
case BaseType::Double:
}}}}
static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); }
+
+ typedef $(kBaseTypes[tt].name) Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(0);
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(Differential a, Differential b)
+ {
+ return a * b;
+ }
${{{{
break;
}
@@ -891,7 +911,6 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << " __init(" << kBaseTypes[ff].name << " value);\n";
}
}
-
sb << "}\n";
}
@@ -926,7 +945,6 @@ for( int C = 2; C <= 4; ++C )
if(rr == R && cc == C) continue;
sb << "__init(matrix<T," << rr << "," << cc << "> value);\n";
}
-
sb << "}\n";
}
@@ -935,6 +953,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
if(kBaseTypes[tt].tag == BaseType::Void) continue;
auto toType = kBaseTypes[tt].name;
}}}}
+
__generic<let R : int, let C : int> extension matrix<$(toType),R,C>
{
${{{{
@@ -958,6 +977,60 @@ ${{{{
}
}}}}
+__generic<T, U>
+__intrinsic_op(0)
+T __slang_noop_cast(U u);
+
+__generic<T:__BuiltinFloatingPointType, let N: int>
+extension vector<T, N> : IDifferentiable
+{
+ typedef vector<T, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
+__generic<T:__BuiltinFloatingPointType, let R: int, let C: int>
+extension matrix<T, R, C> : IDifferentiable
+{
+ typedef matrix<T, R, C> Differential;
+
+ __init(T val);
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return a * b;
+ }
+}
+
//@ public:
/// Sampling state for filtered texture fetches.