diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-06-25 15:45:34 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-25 12:45:34 -0700 |
| commit | 8da47c460df01fad6f1d0614210a770f4781edb1 (patch) | |
| tree | 170a5cc100c69e387e8c6d34217588ea00daed53 /source/slang/slang-parser.cpp | |
| parent | 0229784b93a43e17a088881e6be32b44fc6ce713 (diff) | |
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types
* Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly
* Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-parser.cpp')
| -rw-r--r-- | source/slang/slang-parser.cpp | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index ee34eac6f..d168bf55c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2058,22 +2058,22 @@ namespace Slang } /// Parse an expression of the form __jvp(fn) where fn is an /// identifier pointing to a function. - static Expr* parseJVPDerivativeOf(Parser* parser) + static Expr* parseJVPDifferentiate(Parser* parser) { - JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create<JVPDerivativeOfExpr>(); + JVPDifferentiateExpr* jvpExpr = parser->astBuilder->create<JVPDifferentiateExpr>(); parser->ReadToken(TokenType::LParent); - jvpExpr->baseFn = parser->ParseExpression(); + jvpExpr->baseFunction = parser->ParseExpression(); parser->ReadToken(TokenType::RParent); return jvpExpr; } - static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */) + static NodeBase* parseJVPDifferentiate(Parser* parser, void* /* unused */) { - return parseJVPDerivativeOf(parser); + return parseJVPDifferentiate(parser); } /// Parse a `This` type expression @@ -6492,7 +6492,7 @@ namespace Slang _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__jvp", parseJVPDerivativeOf) + _makeParseExpr("__jvp", parseJVPDifferentiate) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() |
