diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-01 08:46:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-01 08:46:57 -0700 |
| commit | cbc1eff56057f199183bb7c17d8a360326512367 (patch) | |
| tree | 487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/diff.meta.slang | |
| parent | b707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff) | |
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 63 |
1 files changed, 56 insertions, 7 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ad3dfe77c..674531048 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -128,13 +128,37 @@ extension vector<float, 4> : IDifferentiable } } +__magic_type(DifferentialBottomType) +__intrinsic_type($(kIROp_DifferentialBottomType)) +struct __DifferentialBottom : IDifferentiable +{ + typedef __DifferentialBottom Differential; + + __intrinsic_op($(kIROp_DifferentialBottomValue)) + static __DifferentialBottom dzero(); + + [__unsafeForceInlineEarly] + static __DifferentialBottom dadd(Differential a, Differential b) + { + return dzero(); + } + + [__unsafeForceInlineEarly] + static __DifferentialBottom dmul(This a, Differential b) + { + return dzero(); + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic<T : IDifferentiable> __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) -struct __DifferentialPair +struct DifferentialPair : IDifferentiable { + typedef DifferentialPair<T.Differential> Differential; + typedef T.Differential DifferentialElementType; __intrinsic_op($(kIROp_MakeDifferentialPair)) __init(T _primal, T.Differential _differential); @@ -154,6 +178,31 @@ struct __DifferentialPair { return p(); } + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return Differential( + T.dadd( + a.p(), + b.p() + ), + Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return Differential( + T.dmul(a.p(), b.p()), + Differential.DifferentialElementType.dzero()); + } }; typealias IDFloat = IFloat & IDifferentiable; @@ -171,9 +220,9 @@ namespace dstd T exp(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( exp(dpx.p()), T.dmul(exp(dpx.p()), dpx.d())); } @@ -189,9 +238,9 @@ namespace dstd T sin(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( sin(dpx.p()), T.dmul(cos(dpx.p()), dpx.d())); } @@ -207,9 +256,9 @@ namespace dstd T cos(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } |
