From 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 9 Nov 2022 19:19:17 -0800 Subject: Add `[ForwardDerivativeOf]` attribute. (#2501) * Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He --- source/slang/diff.meta.slang | 126 ++++++++++++++++--------------------------- 1 file changed, 47 insertions(+), 79 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c95f8e1ac..1f6064983 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -10,6 +10,13 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; + +__attributeTarget(DeclBase) +attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; + + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -83,85 +90,46 @@ struct DifferentialPair : IDifferentiable #define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ vector result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result -namespace dstd +// Natural Exponent + +__generic +[ForwardDerivativeOf(exp)] +DifferentialPair __d_exp(DifferentialPair dpx) { - // Natural Exponent - __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp)] - T exp(T x); - - __generic - DifferentialPair d_exp(DifferentialPair dpx) - { - return DifferentialPair( - dstd.exp(dpx.p), - T.dmul(dstd.exp(dpx.p), dpx.d)); - } - - // Sine - __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_sin)] - T sin(T x); - - __generic - DifferentialPair d_sin(DifferentialPair dpx) - { - return DifferentialPair( - dstd.sin(dpx.p), - T.dmul(dstd.cos(dpx.p), dpx.d)); - } - - // Cosine - __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_cos)] - T cos(T x); - - __generic - DifferentialPair d_cos(DifferentialPair dpx) - { - return DifferentialPair( - dstd.cos(dpx.p), - T.dmul(-dstd.sin(dpx.p), dpx.d)); - } - - __generic - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp_vector)] - vector exp(vector x) - { - VECTOR_MAP_UNARY(float, N, dstd.exp, x); - } - - __generic - DifferentialPair> d_exp_vector(DifferentialPair> dpx) - { - vector result; - vector.Differential d_result; - for(int i = 0; i < N; ++i) - { - DifferentialPair dpexp = dstd.d_exp(DifferentialPair(dpx.p[i], dpx.d[i])); - result[i] = dpexp.p; - d_result[i] = dpexp.d; - } - - return DifferentialPair>(result, d_result); + return DifferentialPair( + exp(dpx.p), + T.dmul(exp(dpx.p), dpx.d)); +} + +__generic +[ForwardDerivativeOf(exp)] +DifferentialPair> __d_exp_vector(DifferentialPair> dpx) +{ + vector result; + vector.Differential d_result; + for(int i = 0; i < N; ++i) + { + DifferentialPair dpexp = __d_exp(DifferentialPair(dpx.p[i], __slang_noop_cast(dpx.d[i]))); + result[i] = dpexp.p; + d_result[i] = __slang_noop_cast(dpexp.d); } + return DifferentialPair>(result, d_result); +} -}; +__generic +[ForwardDerivativeOf(sin)] +DifferentialPair d_sin(DifferentialPair dpx) +{ + return DifferentialPair( + sin(dpx.p), + T.dmul(cos(dpx.p), dpx.d)); +} + +__generic +[ForwardDerivativeOf(cos)] +DifferentialPair d_cos(DifferentialPair dpx) +{ + return DifferentialPair( + cos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); +} -- cgit v1.2.3