summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-01 08:46:57 -0700
committerGitHub <noreply@github.com>2022-11-01 08:46:57 -0700
commitcbc1eff56057f199183bb7c17d8a360326512367 (patch)
tree487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/diff.meta.slang
parentb707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff)
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang63
1 files changed, 56 insertions, 7 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ad3dfe77c..674531048 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -128,13 +128,37 @@ extension vector<float, 4> : IDifferentiable
}
}
+__magic_type(DifferentialBottomType)
+__intrinsic_type($(kIROp_DifferentialBottomType))
+struct __DifferentialBottom : IDifferentiable
+{
+ typedef __DifferentialBottom Differential;
+
+ __intrinsic_op($(kIROp_DifferentialBottomValue))
+ static __DifferentialBottom dzero();
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dadd(Differential a, Differential b)
+ {
+ return dzero();
+ }
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dmul(This a, Differential b)
+ {
+ return dzero();
+ }
+}
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
-struct __DifferentialPair
+struct DifferentialPair : IDifferentiable
{
+ typedef DifferentialPair<T.Differential> Differential;
+ typedef T.Differential DifferentialElementType;
__intrinsic_op($(kIROp_MakeDifferentialPair))
__init(T _primal, T.Differential _differential);
@@ -154,6 +178,31 @@ struct __DifferentialPair
{
return p();
}
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(T.dzero(), Differential.DifferentialElementType.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return Differential(
+ T.dadd(
+ a.p(),
+ b.p()
+ ),
+ Differential.DifferentialElementType.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return Differential(
+ T.dmul(a.p(), b.p()),
+ Differential.DifferentialElementType.dzero());
+ }
};
typealias IDFloat = IFloat & IDifferentiable;
@@ -171,9 +220,9 @@ namespace dstd
T exp(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
exp(dpx.p()),
T.dmul(exp(dpx.p()), dpx.d()));
}
@@ -189,9 +238,9 @@ namespace dstd
T sin(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
sin(dpx.p()),
T.dmul(cos(dpx.p()), dpx.d()));
}
@@ -207,9 +256,9 @@ namespace dstd
T cos(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
cos(dpx.p()),
T.dmul(-sin(dpx.p()), dpx.d()));
}