diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-08 10:07:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-08 10:07:57 -0800 |
| commit | bf67309454032b4f92d0bc9735b608e56b16882f (patch) | |
| tree | a321fe7db0b49fa67608b935c1389354a020f59c /source/slang/core.meta.slang | |
| parent | ca882a1ef46a5a8bbff50e3a1a6f973e16358634 (diff) | |
Make `__BuiltinFloatingPointType` conform to `IDifferentiable`. (#2499)
Diffstat (limited to 'source/slang/core.meta.slang')
| -rw-r--r-- | source/slang/core.meta.slang | 79 |
1 files changed, 76 insertions, 3 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 05963bd11..a37124bdc 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -162,7 +162,7 @@ interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] [builtin] -interface __BuiltinFloatingPointType : __BuiltinRealType +interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable { /// Initialize from a 32-bit floating-point value. __init(float value); @@ -369,6 +369,26 @@ ${{{{ case BaseType::Double: }}}} static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); } + + typedef $(kBaseTypes[tt].name) Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(0); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(Differential a, Differential b) + { + return a * b; + } ${{{{ break; } @@ -891,7 +911,6 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) sb << " __init(" << kBaseTypes[ff].name << " value);\n"; } } - sb << "}\n"; } @@ -926,7 +945,6 @@ for( int C = 2; C <= 4; ++C ) if(rr == R && cc == C) continue; sb << "__init(matrix<T," << rr << "," << cc << "> value);\n"; } - sb << "}\n"; } @@ -935,6 +953,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) if(kBaseTypes[tt].tag == BaseType::Void) continue; auto toType = kBaseTypes[tt].name; }}}} + __generic<let R : int, let C : int> extension matrix<$(toType),R,C> { ${{{{ @@ -958,6 +977,60 @@ ${{{{ } }}}} +__generic<T, U> +__intrinsic_op(0) +T __slang_noop_cast(U u); + +__generic<T:__BuiltinFloatingPointType, let N: int> +extension vector<T, N> : IDifferentiable +{ + typedef vector<T, N> Differential; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + +__generic<T:__BuiltinFloatingPointType, let R: int, let C: int> +extension matrix<T, R, C> : IDifferentiable +{ + typedef matrix<T, R, C> Differential; + + __init(T val); + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return a * b; + } +} + //@ public: /// Sampling state for filtered texture fetches. |
