diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-04 09:36:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-04 09:36:23 -0700 |
| commit | c6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch) | |
| tree | 6db694b5b4bf94ce48678c73921676f9d305614d /source/slang/diff.meta.slang | |
| parent | 015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff) | |
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 58 |
1 files changed, 6 insertions, 52 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 1c3066e1d..ae4db603e 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,32 +9,6 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; -/// Interface to denote types as differentiable. -/// Allows for user-specified differential types as -/// well as automatic generation, for when the associated type -/// hasn't been declared explicitly. -/// Note that the requirements must currently be defined in this exact order -/// since the auto-diff pass relies on the order to grab the struct keys. -/// -__magic_type(DifferentiableType) -interface IDifferentiable -{ - // Note: the compiler implementation requires the `Differential` associated type to be defined - // before anything else. - - [__BuiltinRequirement(_BuiltinRequirementKind.DifferentialType)] - associatedtype Differential; - - [__BuiltinRequirement(_BuiltinRequirementKind.DZeroFunc)] - static Differential dzero(); - - [__BuiltinRequirement(_BuiltinRequirementKind.DAddFunc)] - static Differential dadd(Differential, Differential); - - [__BuiltinRequirement(_BuiltinRequirementKind.DMulFunc)] - static Differential dmul(This, Differential); -}; - // Add extensions for the standard types extension float : IDifferentiable { @@ -83,28 +57,6 @@ extension vector<float, N> : IDifferentiable } } -__magic_type(DifferentialBottomType) -__intrinsic_type($(kIROp_DifferentialBottomType)) -struct __DifferentialBottom : IDifferentiable -{ - typedef __DifferentialBottom Differential; - - __intrinsic_op($(kIROp_DifferentialBottomValue)) - static __DifferentialBottom dzero(); - - [__unsafeForceInlineEarly] - static __DifferentialBottom dadd(Differential a, Differential b) - { - return dzero(); - } - - [__unsafeForceInlineEarly] - static __DifferentialBottom dmul(This a, Differential b) - { - return dzero(); - } -} - /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic<T : IDifferentiable> @@ -121,6 +73,7 @@ struct DifferentialPair : IDifferentiable __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) T.Differential d(); + [__unsafeForceInlineEarly] T.Differential getDifferential() { return d(); @@ -129,6 +82,7 @@ struct DifferentialPair : IDifferentiable __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) T p(); + [__unsafeForceInlineEarly] T getPrimal() { return p(); @@ -137,7 +91,7 @@ struct DifferentialPair : IDifferentiable [__unsafeForceInlineEarly] static Differential dzero() { - return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + return Differential(T.dzero(), T.Differential.dzero()); } [__unsafeForceInlineEarly] @@ -148,15 +102,15 @@ struct DifferentialPair : IDifferentiable a.p(), b.p() ), - Differential.DifferentialElementType.dzero()); + T.Differential.dadd(a.d(), b.d())); } [__unsafeForceInlineEarly] static Differential dmul(This a, Differential b) { return Differential( - T.dmul(a.p(), b.p()), - Differential.DifferentialElementType.dzero()); + T.dmul(a.p(), b.p()), + T.Differential.dmul(a.d(), b.d())); } }; |
