diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-14 15:09:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-14 15:09:22 -0700 |
| commit | 6ac0c6a688b33965ba83c18e68861f8f9c4f5250 (patch) | |
| tree | 04a0ab71ffd2b932a4c4ee1d3b3127507211d443 /source | |
| parent | cd6064201eb8443918054588002a442459113ed4 (diff) | |
Add [Differentiable(n)] syntax to specify max order. (#2883)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 7 |
3 files changed, 9 insertions, 1 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index f91137d4a..c30fdf403 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -126,7 +126,7 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; /// Marks a function for backward-mode differentiation. __attributeTarget(FunctionDeclBase) -attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; +attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 59ac26833..1b829c836 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1184,6 +1184,7 @@ class ForwardDerivativeOfAttribute : public DerivativeOfAttribute class BackwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(BackwardDifferentiableAttribute) + int maxOrder = 0; }; /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index b1f36ca2b..2bc914a65 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -573,6 +573,13 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); } } + else if (auto diffAttr = as<BackwardDifferentiableAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto cint = checkConstantIntVal(attr->args[0]); + if (cint) + diffAttr->maxOrder = (int32_t)cint->value; + } else if (auto formatAttr = as<FormatAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); |
