From 6ac0c6a688b33965ba83c18e68861f8f9c4f5250 Mon Sep 17 00:00:00 2001 From: Yong He Date: Sun, 14 May 2023 15:09:22 -0700 Subject: Add [Differentiable(n)] syntax to specify max order. (#2883) --- source/slang/core.meta.slang | 2 +- source/slang/slang-ast-modifier.h | 1 + source/slang/slang-check-modifier.cpp | 7 +++++++ 3 files changed, 9 insertions(+), 1 deletion(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index f91137d4a..c30fdf403 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -126,7 +126,7 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; /// Marks a function for backward-mode differentiation. __attributeTarget(FunctionDeclBase) -attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; +attribute_syntax [BackwardDifferentiable(order:int = 0)] : BackwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 59ac26833..1b829c836 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1184,6 +1184,7 @@ class ForwardDerivativeOfAttribute : public DerivativeOfAttribute class BackwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(BackwardDifferentiableAttribute) + int maxOrder = 0; }; /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index b1f36ca2b..2bc914a65 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -573,6 +573,13 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), params.getCount()); } } + else if (auto diffAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + auto cint = checkConstantIntVal(attr->args[0]); + if (cint) + diffAttr->maxOrder = (int32_t)cint->value; + } else if (auto formatAttr = as(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); -- cgit v1.2.3