diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 24 |
1 files changed, 8 insertions, 16 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a60a77cc3..859b8a488 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -319,7 +319,13 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma // Detach and set derivatives to zero -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> +T detach(T x) +{ + return x; +} + +__generic<T : IDifferentiable> [ForwardDerivativeOf(detach)] DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) { @@ -329,27 +335,13 @@ DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) ); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[ForwardDerivativeOf(detach)] -DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx) -{ - VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); -} - -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> [BackwardDerivativeOf(detach)] void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut) { dpx = diffPair(dpx.p, T.dzero()); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[BackwardDerivativeOf(detach)] -void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) -{ - dpx = diffPair(dpx.p, vector<T, N>.dzero()); -} - // Natural Exponent __generic<T : __BuiltinFloatingPointType> |
