From cd6064201eb8443918054588002a442459113ed4 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 12 May 2023 16:15:36 -0700 Subject: Add finalized keywords for autodiff. (#2882) Co-authored-by: Yong He --- source/slang/core.meta.slang | 3 +++ source/slang/slang-ast-support-types.cpp | 4 ++-- source/slang/slang-parser.cpp | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 730c3fcc8..f91137d4a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -127,6 +127,9 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; /// Marks a function for backward-mode differentiation. __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; + /// Interface to denote types as differentiable. /// Allows for user-specified differential types as diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 6e3c326fb..b446221ff 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -62,9 +62,9 @@ Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLeve UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr) { if (as(expr)) - return UnownedStringSlice("__fwd_diff"); + return UnownedStringSlice("fwd_diff"); else if (as(expr)) - return UnownedStringSlice("__bwd_diff"); + return UnownedStringSlice("bwd_diff"); return UnownedStringSlice(); } } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index cce4b7e7b..410395669 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6850,6 +6850,8 @@ namespace Slang _makeParseExpr("__TaggedUnion", parseTaggedUnionType), _makeParseExpr("__fwd_diff", parseForwardDifferentiate), _makeParseExpr("__bwd_diff", parseBackwardDifferentiate), + _makeParseExpr("fwd_diff", parseForwardDifferentiate), + _makeParseExpr("bwd_diff", parseBackwardDifferentiate), _makeParseExpr("__dispatch_kernel", parseDispatchKernel) }; -- cgit v1.2.3