diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-14 15:09:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-14 15:09:22 -0700 |
| commit | 6ac0c6a688b33965ba83c18e68861f8f9c4f5250 (patch) | |
| tree | 04a0ab71ffd2b932a4c4ee1d3b3127507211d443 | |
| parent | cd6064201eb8443918054588002a442459113ed4 (diff) | |
Add [Differentiable(n)] syntax to specify max order. (#2883)
| -rw-r--r-- | source/slang/core.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 7 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop.slang | 2 |
4 files changed, 10 insertions, 2 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index f91137d4a..c30fdf403 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -126,7 +126,7 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; /// Marks a function for backward-mode differentiation. __attributeTarget(FunctionDeclBase) -attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; +attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 59ac26833..1b829c836 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1184,6 +1184,7 @@ class ForwardDerivativeOfAttribute : public DerivativeOfAttribute class BackwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(BackwardDifferentiableAttribute) + int maxOrder = 0; }; /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index b1f36ca2b..2bc914a65 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -573,6 +573,13 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); } } + else if (auto diffAttr = as<BackwardDifferentiableAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto cint = checkConstantIntVal(attr->args[0]); + if (cint) + diffAttr->maxOrder = (int32_t)cint->value; + } else if (auto formatAttr = as<FormatAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index 5598f6b71..a2c826be9 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -8,7 +8,7 @@ RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; -[BackwardDifferentiable] +[Differentiable] float test_simple_loop(float y) { float t = y; |
