diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-08 21:52:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-08 21:52:34 -0800 |
| commit | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch) | |
| tree | b4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-check-expr.cpp | |
| parent | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff) | |
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`.
* Fix
* Fix.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 79 |
1 files changed, 60 insertions, 19 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index bebaa63a2..f749361d7 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2227,10 +2227,10 @@ namespace Slang return type; } - struct DifferentiateExprCheckingActions + struct HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0; - virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0; + virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) { if (auto funcType = as<FuncType>(funcExpr->type.type)) @@ -2255,13 +2255,13 @@ namespace Slang } }; - struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2290,13 +2290,13 @@ namespace Slang } }; - struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions { - virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override { return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>(); } - void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override { resultDiffExpr->baseFunction = funcExpr; auto baseFuncType = getBaseFunctionType(semantics, funcExpr); @@ -2333,10 +2333,45 @@ namespace Slang } }; - static Expr* _checkDifferentiateExpr( + struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions + { + virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<PrimalSubstituteExpr>(); + } + void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; + } + resultDiffExpr->type = baseFuncType; + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl(); + if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl())) + { + funcDecl = as<CallableDecl>(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + } + } + }; + + static Expr* _checkHigherOrderInvokeExpr( SemanticsVisitor* semantics, - DifferentiateExpr* expr, - DifferentiateExprCheckingActions* actions) + HigherOrderInvokeExpr* expr, + HigherOrderInvokeExprCheckingActions* actions) { // Check/Resolve inner function declaration. expr->baseFunction = semantics->CheckTerm(expr->baseFunction); @@ -2354,8 +2389,8 @@ namespace Slang nullptr, overloadedExpr->loc, nullptr); - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2368,8 +2403,8 @@ namespace Slang OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); for (auto item : overloadedExpr2->candidiateExprs) { - auto candidateExpr = actions->createDifferentiateExpr(semantics); - actions->fillDifferentiateExpr(candidateExpr, semantics, item); + auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); + actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } @@ -2378,20 +2413,26 @@ namespace Slang return result; } - actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction); + actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction); return expr; } Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { ForwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) { BackwardDifferentiateExprCheckingActions actions; - return _checkDifferentiateExpr(this, expr, &actions); + return _checkHigherOrderInvokeExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + PrimalSubstituteExprCheckingActions actions; + return _checkHigherOrderInvokeExpr(this, expr, &actions); } Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) |
