diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-20 14:42:50 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-20 14:42:50 -0800 |
| commit | 47715e625337d489f3c0131bbc2b849378b48a5a (patch) | |
| tree | bc737c8f03ef537b2ac39860bbb922c7600edc43 /source/slang/diff.meta.slang | |
| parent | 8b05df4187117d61491f2fdbeb7d744146ad73f7 (diff) | |
Miscellaneous backward autodiff fixes. (#2665)
* Fix differentiable type registration
* Fix use of non-differentiable return value in a differentiable func.
* Fix use of primal inst that does not dominate the diff block.
* Fix primal inst hoisting, and add missing type legalization logic.
* Make `detach` defined on all differentiable T.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 24 |
1 files changed, 8 insertions, 16 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a60a77cc3..859b8a488 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -319,7 +319,13 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma // Detach and set derivatives to zero -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> +T detach(T x) +{ + return x; +} + +__generic<T : IDifferentiable> [ForwardDerivativeOf(detach)] DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) { @@ -329,27 +335,13 @@ DifferentialPair<T> __d_detach(DifferentialPair<T> dpx) ); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[ForwardDerivativeOf(detach)] -DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx) -{ - VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx); -} - -__generic<T : __BuiltinFloatingPointType> +__generic<T : IDifferentiable> [BackwardDerivativeOf(detach)] void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut) { dpx = diffPair(dpx.p, T.dzero()); } -__generic<T : __BuiltinFloatingPointType, let N : int> -[BackwardDerivativeOf(detach)] -void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) -{ - dpx = diffPair(dpx.p, vector<T, N>.dzero()); -} - // Natural Exponent __generic<T : __BuiltinFloatingPointType> |
