From cbc1eff56057f199183bb7c17d8a360326512367 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 1 Nov 2022 08:46:57 -0700 Subject: Make `DifferentialPair` able to nest. (#2477) --- source/slang/diff.meta.slang | 63 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 7 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ad3dfe77c..674531048 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -128,13 +128,37 @@ extension vector : IDifferentiable } } +__magic_type(DifferentialBottomType) +__intrinsic_type($(kIROp_DifferentialBottomType)) +struct __DifferentialBottom : IDifferentiable +{ + typedef __DifferentialBottom Differential; + + __intrinsic_op($(kIROp_DifferentialBottomValue)) + static __DifferentialBottom dzero(); + + [__unsafeForceInlineEarly] + static __DifferentialBottom dadd(Differential a, Differential b) + { + return dzero(); + } + + [__unsafeForceInlineEarly] + static __DifferentialBottom dmul(This a, Differential b) + { + return dzero(); + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) -struct __DifferentialPair +struct DifferentialPair : IDifferentiable { + typedef DifferentialPair Differential; + typedef T.Differential DifferentialElementType; __intrinsic_op($(kIROp_MakeDifferentialPair)) __init(T _primal, T.Differential _differential); @@ -154,6 +178,31 @@ struct __DifferentialPair { return p(); } + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return Differential( + T.dadd( + a.p(), + b.p() + ), + Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return Differential( + T.dmul(a.p(), b.p()), + Differential.DifferentialElementType.dzero()); + } }; typealias IDFloat = IFloat & IDifferentiable; @@ -171,9 +220,9 @@ namespace dstd T exp(T x); __generic - __DifferentialPair d_exp(__DifferentialPair dpx) + DifferentialPair d_exp(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( exp(dpx.p()), T.dmul(exp(dpx.p()), dpx.d())); } @@ -189,9 +238,9 @@ namespace dstd T sin(T x); __generic - __DifferentialPair d_sin(__DifferentialPair dpx) + DifferentialPair d_sin(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( sin(dpx.p()), T.dmul(cos(dpx.p()), dpx.d())); } @@ -207,9 +256,9 @@ namespace dstd T cos(T x); __generic - __DifferentialPair d_cos(__DifferentialPair dpx) + DifferentialPair d_cos(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } -- cgit v1.2.3