diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/diff.meta.slang | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 63 |
1 files changed, 61 insertions, 2 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f314e0487..ea204c839 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -5,7 +5,7 @@ syntax __differentiate_jvp : JVPDerivativeModifier; // Custom JVP Function reference -__attributeTarget(FuncDecl) +__attributeTarget(FunctionDeclBase) attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; /// Interface to denote types as differentiable. @@ -39,7 +39,7 @@ extension float : IDifferentiable [__unsafeForceInlineEarly] static Differential zero() { - return 0.f; + return float(0.f); } [__unsafeForceInlineEarly] @@ -151,3 +151,62 @@ struct __DifferentialPair return p(); } }; + +typealias IDFloat = IFloat & IDifferentiable; + +namespace dstd +{ + // Natural Exponent + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_exp($0)") + __target_intrinsic(cpp, "$P_exp($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") + [__custom_jvp(d_exp<T>)] + T exp(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + exp(dpx.p()), + T.dmul(exp(dpx.p()), dpx.d())); + } + + // Sine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_sin($0)") + __target_intrinsic(cpp, "$P_sin($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") + [__custom_jvp(d_sin<T>)] + T sin(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + sin(dpx.p()), + T.dmul(cos(dpx.p()), dpx.d())); + } + + // Cosine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_cos($0)") + __target_intrinsic(cpp, "$P_cos($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") + [__custom_jvp(d_cos<T>)] + T cos(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + cos(dpx.p()), + T.dmul(-sin(dpx.p()), dpx.d())); + } +}; |
