diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-09 19:19:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 19:19:17 -0800 |
| commit | 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch) | |
| tree | cbc942746bab043da0eb5298993d95f9665dfddf /source/slang/diff.meta.slang | |
| parent | cedd93690c63188cf98e452c9d104cf51aad6c4e (diff) | |
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute.
* Fix handling around phi nodes.
* Fixes.
* Remove IR opcode for ForwardDerivativeOfDecoration.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 126 |
1 files changed, 47 insertions, 79 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c95f8e1ac..1f6064983 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -10,6 +10,13 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; + +__attributeTarget(DeclBase) +attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; + + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -83,85 +90,46 @@ struct DifferentialPair : IDifferentiable #define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result -namespace dstd +// Natural Exponent + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(exp)] +DifferentialPair<T> __d_exp(DifferentialPair<T> dpx) { - // Natural Exponent - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp<T>)] - T exp(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_exp(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.exp(dpx.p), - T.dmul(dstd.exp(dpx.p), dpx.d)); - } - - // Sine - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_sin<T>)] - T sin(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_sin(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.sin(dpx.p), - T.dmul(dstd.cos(dpx.p), dpx.d)); - } - - // Cosine - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_cos<T>)] - T cos(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_cos(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.cos(dpx.p), - T.dmul(-dstd.sin(dpx.p), dpx.d)); - } - - __generic<let N : int> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp_vector)] - vector<float, N> exp(vector<float, N> x) - { - VECTOR_MAP_UNARY(float, N, dstd.exp, x); - } - - __generic<let N : int> - DifferentialPair<vector<float, N>> d_exp_vector(DifferentialPair<vector<float, N>> dpx) - { - vector<float, N> result; - vector<float, N>.Differential d_result; - for(int i = 0; i < N; ++i) - { - DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p[i], dpx.d[i])); - result[i] = dpexp.p; - d_result[i] = dpexp.d; - } - - return DifferentialPair<vector<float, N>>(result, d_result); + return DifferentialPair<T>( + exp(dpx.p), + T.dmul(exp(dpx.p), dpx.d)); +} + +__generic<T:__BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(exp)] +DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx) +{ + vector<T, N> result; + vector<T, N>.Differential d_result; + for(int i = 0; i < N; ++i) + { + DifferentialPair<T> dpexp = __d_exp(DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); + result[i] = dpexp.p; + d_result[i] = __slang_noop_cast<T>(dpexp.d); } + return DifferentialPair<vector<T, N>>(result, d_result); +} -}; +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(sin)] +DifferentialPair<T> d_sin(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + sin(dpx.p), + T.dmul(cos(dpx.p), dpx.d)); +} + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(cos)] +DifferentialPair<T> d_cos(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + cos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); +} |
