From f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Oct 2022 22:21:29 -0400 Subject: Adding a differentiable standard library (#2465) --- source/slang/diff.meta.slang | 63 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f314e0487..ea204c839 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -5,7 +5,7 @@ syntax __differentiate_jvp : JVPDerivativeModifier; // Custom JVP Function reference -__attributeTarget(FuncDecl) +__attributeTarget(FunctionDeclBase) attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; /// Interface to denote types as differentiable. @@ -39,7 +39,7 @@ extension float : IDifferentiable [__unsafeForceInlineEarly] static Differential zero() { - return 0.f; + return float(0.f); } [__unsafeForceInlineEarly] @@ -151,3 +151,62 @@ struct __DifferentialPair return p(); } }; + +typealias IDFloat = IFloat & IDifferentiable; + +namespace dstd +{ + // Natural Exponent + __generic + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_exp($0)") + __target_intrinsic(cpp, "$P_exp($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") + [__custom_jvp(d_exp)] + T exp(T x); + + __generic + __DifferentialPair d_exp(__DifferentialPair dpx) + { + return __DifferentialPair( + exp(dpx.p()), + T.dmul(exp(dpx.p()), dpx.d())); + } + + // Sine + __generic + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_sin($0)") + __target_intrinsic(cpp, "$P_sin($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") + [__custom_jvp(d_sin)] + T sin(T x); + + __generic + __DifferentialPair d_sin(__DifferentialPair dpx) + { + return __DifferentialPair( + sin(dpx.p()), + T.dmul(cos(dpx.p()), dpx.d())); + } + + // Cosine + __generic + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_cos($0)") + __target_intrinsic(cpp, "$P_cos($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") + [__custom_jvp(d_cos)] + T cos(T x); + + __generic + __DifferentialPair d_cos(__DifferentialPair dpx) + { + return __DifferentialPair( + cos(dpx.p()), + T.dmul(-sin(dpx.p()), dpx.d())); + } +}; -- cgit v1.2.3