summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/diff.meta.slang
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang63
1 files changed, 61 insertions, 2 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f314e0487..ea204c839 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -5,7 +5,7 @@
syntax __differentiate_jvp : JVPDerivativeModifier;
// Custom JVP Function reference
-__attributeTarget(FuncDecl)
+__attributeTarget(FunctionDeclBase)
attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
/// Interface to denote types as differentiable.
@@ -39,7 +39,7 @@ extension float : IDifferentiable
[__unsafeForceInlineEarly]
static Differential zero()
{
- return 0.f;
+ return float(0.f);
}
[__unsafeForceInlineEarly]
@@ -151,3 +151,62 @@ struct __DifferentialPair
return p();
}
};
+
+typealias IDFloat = IFloat & IDifferentiable;
+
+namespace dstd
+{
+ // Natural Exponent
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_exp($0)")
+ __target_intrinsic(cpp, "$P_exp($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
+ [__custom_jvp(d_exp<T>)]
+ T exp(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ exp(dpx.p()),
+ T.dmul(exp(dpx.p()), dpx.d()));
+ }
+
+ // Sine
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_sin($0)")
+ __target_intrinsic(cpp, "$P_sin($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
+ [__custom_jvp(d_sin<T>)]
+ T sin(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ sin(dpx.p()),
+ T.dmul(cos(dpx.p()), dpx.d()));
+ }
+
+ // Cosine
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_cos($0)")
+ __target_intrinsic(cpp, "$P_cos($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
+ [__custom_jvp(d_cos<T>)]
+ T cos(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ cos(dpx.p()),
+ T.dmul(-sin(dpx.p()), dpx.d()));
+ }
+};