From 47715e625337d489f3c0131bbc2b849378b48a5a Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 20 Feb 2023 14:42:50 -0800 Subject: Miscellaneous backward autodiff fixes. (#2665) * Fix differentiable type registration * Fix use of non-differentiable return value in a differentiable func. * Fix use of primal inst that does not dominate the diff block. * Fix primal inst hoisting, and add missing type legalization logic. * Make `detach` defined on all differentiable T. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a60a77cc3..859b8a488 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -319,7 +319,13 @@ void mul(inout DifferentialPair> left, inout DifferentialPair +__generic +T detach(T x) +{ + return x; +} + +__generic [ForwardDerivativeOf(detach)] DifferentialPair __d_detach(DifferentialPair dpx) { @@ -329,27 +335,13 @@ DifferentialPair __d_detach(DifferentialPair dpx) ); } -__generic -[ForwardDerivativeOf(detach)] -DifferentialPair> __d_detach_vector(DifferentialPair> dpx) -{ - VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); -} - -__generic +__generic [BackwardDerivativeOf(detach)] void __d_detach(inout DifferentialPair dpx, T.Differential dOut) { dpx = diffPair(dpx.p, T.dzero()); } -__generic -[BackwardDerivativeOf(detach)] -void __d_detach_vector(inout DifferentialPair> dpx, vector.Differential dOut) -{ - dpx = diffPair(dpx.p, vector.dzero()); -} - // Natural Exponent __generic -- cgit v1.2.3