From 8da47c460df01fad6f1d0614210a770f4781edb1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:45:34 -0400 Subject: Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303) * Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He --- source/slang/slang-parser.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'source/slang/slang-parser.cpp') diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index ee34eac6f..d168bf55c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2058,22 +2058,22 @@ namespace Slang } /// Parse an expression of the form __jvp(fn) where fn is an /// identifier pointing to a function. - static Expr* parseJVPDerivativeOf(Parser* parser) + static Expr* parseJVPDifferentiate(Parser* parser) { - JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create(); + JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create(); parser->ReadToken(TokenType::LParent); - jvpExpr->baseFn = parser->ParseExpression(); + jvpExpr->baseFunction = parser->ParseExpression(); parser->ReadToken(TokenType::RParent); return jvpExpr; } - static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */) + static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */) { - return parseJVPDerivativeOf(parser); + return parseJVPDifferentiate(parser); } /// Parse a `This` type expression @@ -6492,7 +6492,7 @@ namespace Slang _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__jvp", parseJVPDerivativeOf) + _makeParseExpr("__jvp", parseJVPDifferentiate) }; ConstArrayView getSyntaxParseInfos() -- cgit v1.2.3