summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-04 09:36:23 -0700
committerGitHub <noreply@github.com>2022-11-04 09:36:23 -0700
commitc6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch)
tree6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/diff.meta.slang
parent015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff)
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang58
1 files changed, 6 insertions, 52 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 1c3066e1d..ae4db603e 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,32 +9,6 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
-/// Interface to denote types as differentiable.
-/// Allows for user-specified differential types as
-/// well as automatic generation, for when the associated type
-/// hasn't been declared explicitly.
-/// Note that the requirements must currently be defined in this exact order
-/// since the auto-diff pass relies on the order to grab the struct keys.
-///
-__magic_type(DifferentiableType)
-interface IDifferentiable
-{
- // Note: the compiler implementation requires the `Differential` associated type to be defined
- // before anything else.
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DifferentialType)]
- associatedtype Differential;
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DZeroFunc)]
- static Differential dzero();
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DAddFunc)]
- static Differential dadd(Differential, Differential);
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DMulFunc)]
- static Differential dmul(This, Differential);
-};
-
// Add extensions for the standard types
extension float : IDifferentiable
{
@@ -83,28 +57,6 @@ extension vector<float, N> : IDifferentiable
}
}
-__magic_type(DifferentialBottomType)
-__intrinsic_type($(kIROp_DifferentialBottomType))
-struct __DifferentialBottom : IDifferentiable
-{
- typedef __DifferentialBottom Differential;
-
- __intrinsic_op($(kIROp_DifferentialBottomValue))
- static __DifferentialBottom dzero();
-
- [__unsafeForceInlineEarly]
- static __DifferentialBottom dadd(Differential a, Differential b)
- {
- return dzero();
- }
-
- [__unsafeForceInlineEarly]
- static __DifferentialBottom dmul(This a, Differential b)
- {
- return dzero();
- }
-}
-
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
@@ -121,6 +73,7 @@ struct DifferentialPair : IDifferentiable
__intrinsic_op($(kIROp_DifferentialPairGetDifferential))
T.Differential d();
+ [__unsafeForceInlineEarly]
T.Differential getDifferential()
{
return d();
@@ -129,6 +82,7 @@ struct DifferentialPair : IDifferentiable
__intrinsic_op($(kIROp_DifferentialPairGetPrimal))
T p();
+ [__unsafeForceInlineEarly]
T getPrimal()
{
return p();
@@ -137,7 +91,7 @@ struct DifferentialPair : IDifferentiable
[__unsafeForceInlineEarly]
static Differential dzero()
{
- return Differential(T.dzero(), Differential.DifferentialElementType.dzero());
+ return Differential(T.dzero(), T.Differential.dzero());
}
[__unsafeForceInlineEarly]
@@ -148,15 +102,15 @@ struct DifferentialPair : IDifferentiable
a.p(),
b.p()
),
- Differential.DifferentialElementType.dzero());
+ T.Differential.dadd(a.d(), b.d()));
}
[__unsafeForceInlineEarly]
static Differential dmul(This a, Differential b)
{
return Differential(
- T.dmul(a.p(), b.p()),
- Differential.DifferentialElementType.dzero());
+ T.dmul(a.p(), b.p()),
+ T.Differential.dmul(a.d(), b.d()));
}
};