summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf /source/slang/diff.meta.slang
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang126
1 files changed, 47 insertions, 79 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c95f8e1ac..1f6064983 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -10,6 +10,13 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
+
+__attributeTarget(DeclBase)
+attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
+
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
@@ -83,85 +90,46 @@ struct DifferentialPair : IDifferentiable
#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \
vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result
-namespace dstd
+// Natural Exponent
+
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(exp)]
+DifferentialPair<T> __d_exp(DifferentialPair<T> dpx)
{
- // Natural Exponent
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_exp($0)")
- __target_intrinsic(cpp, "$P_exp($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp<T>)]
- T exp(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.exp(dpx.p),
- T.dmul(dstd.exp(dpx.p), dpx.d));
- }
-
- // Sine
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_sin($0)")
- __target_intrinsic(cpp, "$P_sin($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [ForwardDerivative(d_sin<T>)]
- T sin(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.sin(dpx.p),
- T.dmul(dstd.cos(dpx.p), dpx.d));
- }
-
- // Cosine
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_cos($0)")
- __target_intrinsic(cpp, "$P_cos($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [ForwardDerivative(d_cos<T>)]
- T cos(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.cos(dpx.p),
- T.dmul(-dstd.sin(dpx.p), dpx.d));
- }
-
- __generic<let N : int>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp_vector)]
- vector<float, N> exp(vector<float, N> x)
- {
- VECTOR_MAP_UNARY(float, N, dstd.exp, x);
- }
-
- __generic<let N : int>
- DifferentialPair<vector<float, N>> d_exp_vector(DifferentialPair<vector<float, N>> dpx)
- {
- vector<float, N> result;
- vector<float, N>.Differential d_result;
- for(int i = 0; i < N; ++i)
- {
- DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p[i], dpx.d[i]));
- result[i] = dpexp.p;
- d_result[i] = dpexp.d;
- }
-
- return DifferentialPair<vector<float, N>>(result, d_result);
+ return DifferentialPair<T>(
+ exp(dpx.p),
+ T.dmul(exp(dpx.p), dpx.d));
+}
+
+__generic<T:__BuiltinFloatingPointType, let N : int>
+[ForwardDerivativeOf(exp)]
+DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx)
+{
+ vector<T, N> result;
+ vector<T, N>.Differential d_result;
+ for(int i = 0; i < N; ++i)
+ {
+ DifferentialPair<T> dpexp = __d_exp(DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
+ result[i] = dpexp.p;
+ d_result[i] = __slang_noop_cast<T>(dpexp.d);
}
+ return DifferentialPair<vector<T, N>>(result, d_result);
+}
-};
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(sin)]
+DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
+{
+ return DifferentialPair<T>(
+ sin(dpx.p),
+ T.dmul(cos(dpx.p), dpx.d));
+}
+
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(cos)]
+DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
+{
+ return DifferentialPair<T>(
+ cos(dpx.p),
+ T.dmul(-sin(dpx.p), dpx.d));
+}