summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-20 14:42:50 -0800
committerGitHub <noreply@github.com>2023-02-20 14:42:50 -0800
commit47715e625337d489f3c0131bbc2b849378b48a5a (patch)
treebc737c8f03ef537b2ac39860bbb922c7600edc43 /source/slang/diff.meta.slang
parent8b05df4187117d61491f2fdbeb7d744146ad73f7 (diff)
Miscellaneous backward autodiff fixes. (#2665)
* Fix differentiable type registration * Fix use of non-differentiable return value in a differentiable func. * Fix use of primal inst that does not dominate the diff block. * Fix primal inst hoisting, and add missing type legalization logic. * Make `detach` defined on all differentiable T. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang24
1 files changed, 8 insertions, 16 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a60a77cc3..859b8a488 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -319,7 +319,13 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
// Detach and set derivatives to zero
-__generic<T : __BuiltinFloatingPointType>
+__generic<T : IDifferentiable>
+T detach(T x)
+{
+ return x;
+}
+
+__generic<T : IDifferentiable>
[ForwardDerivativeOf(detach)]
DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
{
@@ -329,27 +335,13 @@ DifferentialPair<T> __d_detach(DifferentialPair<T> dpx)
);
}
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[ForwardDerivativeOf(detach)]
-DifferentialPair<vector<T, N>> __d_detach_vector(DifferentialPair<vector<T, N>> dpx)
-{
- VECTOR_MAP_D_UNARY(T, N, __d_detach, dpx);
-}
-
-__generic<T : __BuiltinFloatingPointType>
+__generic<T : IDifferentiable>
[BackwardDerivativeOf(detach)]
void __d_detach(inout DifferentialPair<T> dpx, T.Differential dOut)
{
dpx = diffPair(dpx.p, T.dzero());
}
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[BackwardDerivativeOf(detach)]
-void __d_detach_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut)
-{
- dpx = diffPair(dpx.p, vector<T, N>.dzero());
-}
-
// Natural Exponent
__generic<T : __BuiltinFloatingPointType>