diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-12 16:15:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-12 16:15:36 -0700 |
| commit | cd6064201eb8443918054588002a442459113ed4 (patch) | |
| tree | 9dc5e60abe2aa8bca810c631fdd62758093992f5 /source | |
| parent | 65103bc9a0c72117d3c9410e361947cdd568ae55 (diff) | |
Add finalized keywords for autodiff. (#2882)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 2 |
3 files changed, 7 insertions, 2 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 730c3fcc8..f91137d4a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -127,6 +127,9 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; /// Marks a function for backward-mode differentiation. __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribute; + /// Interface to denote types as differentiable. /// Allows for user-specified differential types as diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 6e3c326fb..b446221ff 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -62,9 +62,9 @@ Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLeve UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr) { if (as<ForwardDifferentiateExpr>(expr)) - return UnownedStringSlice("__fwd_diff"); + return UnownedStringSlice("fwd_diff"); else if (as<BackwardDifferentiateExpr>(expr)) - return UnownedStringSlice("__bwd_diff"); + return UnownedStringSlice("bwd_diff"); return UnownedStringSlice(); } } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index cce4b7e7b..410395669 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6850,6 +6850,8 @@ namespace Slang _makeParseExpr("__TaggedUnion", parseTaggedUnionType), _makeParseExpr("__fwd_diff", parseForwardDifferentiate), _makeParseExpr("__bwd_diff", parseBackwardDifferentiate), + _makeParseExpr("fwd_diff", parseForwardDifferentiate), + _makeParseExpr("bwd_diff", parseBackwardDifferentiate), _makeParseExpr("__dispatch_kernel", parseDispatchKernel) }; |
