diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-08 10:07:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-08 10:07:57 -0800 |
| commit | bf67309454032b4f92d0bc9735b608e56b16882f (patch) | |
| tree | a321fe7db0b49fa67608b935c1389354a020f59c /source/slang/diff.meta.slang | |
| parent | ca882a1ef46a5a8bbff50e3a1a6f973e16358634 (diff) | |
Make `__BuiltinFloatingPointType` conform to `IDifferentiable`. (#2499)
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 76 |
1 files changed, 14 insertions, 62 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 2625d79b0..c95f8e1ac 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,56 +9,10 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; -// Add extensions for the standard types -extension float : IDifferentiable -{ - typedef float Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return float(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} - -__generic<let N:int> -extension vector<float, N> : IDifferentiable -{ - typedef vector<float, N> Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return vector<float, N>(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} +/// Pair type that serves to wrap the primal and +/// differential types of an arbitrary type T. - /// Pair type that serves to wrap the primal and - /// differential types of an arbitrary type T. __generic<T : IDifferentiable> __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) @@ -126,15 +80,13 @@ struct DifferentialPair : IDifferentiable } }; -typealias IDFloat = IFloat & IDifferentiable; - #define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result namespace dstd { // Natural Exponent - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_exp($0)") @@ -143,16 +95,16 @@ namespace dstd [ForwardDerivative(d_exp<T>)] T exp(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { return DifferentialPair<T>( - exp(dpx.p), - T.dmul(exp(dpx.p), dpx.d)); + dstd.exp(dpx.p), + T.dmul(dstd.exp(dpx.p), dpx.d)); } // Sine - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_sin($0)") @@ -161,16 +113,16 @@ namespace dstd [ForwardDerivative(d_sin<T>)] T sin(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { return DifferentialPair<T>( - sin(dpx.p), - T.dmul(cos(dpx.p), dpx.d)); + dstd.sin(dpx.p), + T.dmul(dstd.cos(dpx.p), dpx.d)); } // Cosine - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_cos($0)") @@ -179,12 +131,12 @@ namespace dstd [ForwardDerivative(d_cos<T>)] T cos(T x); - __generic<T : IDFloat> + __generic<T : __BuiltinFloatingPointType> DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { return DifferentialPair<T>( - cos(dpx.p), - T.dmul(-sin(dpx.p), dpx.d)); + dstd.cos(dpx.p), + T.dmul(-dstd.sin(dpx.p), dpx.d)); } __generic<let N : int> |
