From 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 3 Feb 2023 16:44:33 -0800 Subject: Overhaul `transposeParameterBlock` to support `inout` params. (#2621) * Overhaul `transposeParameterBlock` to support `inout` params. * Small bug fixes. * Bug fix on differentiable intrinsic specialization. * Fixes. * Run autodiff tests on CPU. * Clean up. * More bug fixes., * Add test coverage on inout param. * Fix language server hinting for transcribed mutable params. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index adbf8ae48..055c44135 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -313,6 +313,24 @@ DifferentialPair> __d_sin_vector(DifferentialPair> dpx VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx); } +__generic +[BackwardDerivativeOf(sin)] +void __d_sin(inout DifferentialPair dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(cos(dpx.p), dOut)); +} + +__generic +[BackwardDerivativeOf(sin)] +void __d_sin_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector.dmul(cos(dpx.p), dOut)); +} + // Cosine __generic @@ -331,6 +349,24 @@ DifferentialPair> __d_cos_vector(DifferentialPair> dpx VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx); } +__generic +[BackwardDerivativeOf(cos)] +void __d_cos(inout DifferentialPair dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(-sin(dpx.p), dOut)); +} + +__generic +[BackwardDerivativeOf(cos)] +void __d_cos_vector(inout DifferentialPair> dpx, vector.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector.dmul(-sin(dpx.p), dOut)); +} + // Base-e logarithm __generic -- cgit v1.2.3