summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-parser.cpp
diff options
context:
space:
mode:
authorEdward Liu <shiqiu1105@gmail.com>2022-11-14 12:08:01 -0800
committerGitHub <noreply@github.com>2022-11-14 12:08:01 -0800
commit368ec3116ea0f10f44acbf76b5dc9e34d6ff3d32 (patch)
tree3d9def111db278affb8413bddb5aab9ce3cf73a6 /source/slang/slang-parser.cpp
parent623f5c36e0dc8190753aa5fa2e89f1010c367c67 (diff)
Minimum binary arithmetic reverse autodiff working. (#2514)
* Initial plumbing of backward autodiff in the frontend. * More plumbing. * Initial reverse autodiff working. * Bug fixes. * Misc. * Remove redundant code. * More clean up. * Misc. * Rebase and add backward diff test. * Disable test. * Clean up. * Minor fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-parser.cpp')
-rw-r--r--source/slang/slang-parser.cpp23
1 files changed, 22 insertions, 1 deletions
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 78edd4deb..d3dc5964e 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2109,6 +2109,26 @@ namespace Slang
return parseForwardDifferentiate(parser);
}
+ /// Parse an expression of the form __bwd_diff(fn) where fn is an
+ /// identifier pointing to a function.
+ static Expr* parseBackwardDifferentiate(Parser* parser)
+ {
+ BackwardDifferentiateExpr* bwdDiffExpr = parser->astBuilder->create<BackwardDifferentiateExpr>();
+
+ parser->ReadToken(TokenType::LParent);
+
+ bwdDiffExpr->baseFunction = parser->ParseExpression();
+
+ parser->ReadToken(TokenType::RParent);
+
+ return bwdDiffExpr;
+ }
+
+ static NodeBase* parseBackwardDifferentiate(Parser* parser, void* /* unused */)
+ {
+ return parseBackwardDifferentiate(parser);
+ }
+
/// Parse a `This` type expression
static Expr* parseThisTypeExpr(Parser* parser)
{
@@ -6646,7 +6666,8 @@ namespace Slang
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
- _makeParseExpr("__fwd_diff", parseForwardDifferentiate)
+ _makeParseExpr("__fwd_diff", parseForwardDifferentiate),
+ _makeParseExpr("__bwd_diff", parseBackwardDifferentiate)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()