summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-12 16:15:36 -0700
committerGitHub <noreply@github.com>2023-05-12 16:15:36 -0700
commitcd6064201eb8443918054588002a442459113ed4 (patch)
tree9dc5e60abe2aa8bca810c631fdd62758093992f5 /source
parent65103bc9a0c72117d3c9410e361947cdd568ae55 (diff)
Add finalized keywords for autodiff. (#2882)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-support-types.cpp4
-rw-r--r--source/slang/slang-parser.cpp2
3 files changed, 7 insertions, 2 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 730c3fcc8..f91137d4a 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -127,6 +127,9 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
/// Marks a function for backward-mode differentiation.
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute;
+
/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index 6e3c326fb..b446221ff 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -62,9 +62,9 @@ Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLeve
UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr)
{
if (as<ForwardDifferentiateExpr>(expr))
- return UnownedStringSlice("__fwd_diff");
+ return UnownedStringSlice("fwd_diff");
else if (as<BackwardDifferentiateExpr>(expr))
- return UnownedStringSlice("__bwd_diff");
+ return UnownedStringSlice("bwd_diff");
return UnownedStringSlice();
}
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index cce4b7e7b..410395669 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6850,6 +6850,8 @@ namespace Slang
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
_makeParseExpr("__fwd_diff", parseForwardDifferentiate),
_makeParseExpr("__bwd_diff", parseBackwardDifferentiate),
+ _makeParseExpr("fwd_diff", parseForwardDifferentiate),
+ _makeParseExpr("bwd_diff", parseBackwardDifferentiate),
_makeParseExpr("__dispatch_kernel", parseDispatchKernel)
};