summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-14 15:09:22 -0700
committerGitHub <noreply@github.com>2023-05-14 15:09:22 -0700
commit6ac0c6a688b33965ba83c18e68861f8f9c4f5250 (patch)
tree04a0ab71ffd2b932a4c4ee1d3b3127507211d443
parentcd6064201eb8443918054588002a442459113ed4 (diff)
Add [Differentiable(n)] syntax to specify max order. (#2883)
-rw-r--r--source/slang/core.meta.slang2
-rw-r--r--source/slang/slang-ast-modifier.h1
-rw-r--r--source/slang/slang-check-modifier.cpp7
-rw-r--r--tests/autodiff/reverse-loop.slang2
4 files changed, 10 insertions, 2 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f91137d4a..c30fdf403 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -126,7 +126,7 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
/// Marks a function for backward-mode differentiation.
__attributeTarget(FunctionDeclBase)
-attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
+attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 59ac26833..1b829c836 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1184,6 +1184,7 @@ class ForwardDerivativeOfAttribute : public DerivativeOfAttribute
class BackwardDifferentiableAttribute : public DifferentiableAttribute
{
SLANG_AST_CLASS(BackwardDifferentiableAttribute)
+ int maxOrder = 0;
};
/// The `[BackwardDerivative(function)]` attribute specifies a custom function that should
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index b1f36ca2b..2bc914a65 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -573,6 +573,13 @@ namespace Slang
getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount());
}
}
+ else if (auto diffAttr = as<BackwardDifferentiableAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ auto cint = checkConstantIntVal(attr->args[0]);
+ if (cint)
+ diffAttr->maxOrder = (int32_t)cint->value;
+ }
else if (auto formatAttr = as<FormatAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang
index 5598f6b71..a2c826be9 100644
--- a/tests/autodiff/reverse-loop.slang
+++ b/tests/autodiff/reverse-loop.slang
@@ -8,7 +8,7 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
-[BackwardDifferentiable]
+[Differentiable]
float test_simple_loop(float y)
{
float t = y;