summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp79
1 files changed, 60 insertions, 19 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index bebaa63a2..f749361d7 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2227,10 +2227,10 @@ namespace Slang
return type;
}
- struct DifferentiateExprCheckingActions
+ struct HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0;
- virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0;
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) = 0;
+ virtual void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0;
FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr)
{
if (auto funcType = as<FuncType>(funcExpr->type.type))
@@ -2255,13 +2255,13 @@ namespace Slang
}
};
- struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ struct ForwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>();
}
- void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
resultDiffExpr->baseFunction = funcExpr;
auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
@@ -2290,13 +2290,13 @@ namespace Slang
}
};
- struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ struct BackwardDifferentiateExprCheckingActions : HigherOrderInvokeExprCheckingActions
{
- virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
{
return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>();
}
- void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
{
resultDiffExpr->baseFunction = funcExpr;
auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
@@ -2333,10 +2333,45 @@ namespace Slang
}
};
- static Expr* _checkDifferentiateExpr(
+ struct PrimalSubstituteExprCheckingActions : HigherOrderInvokeExprCheckingActions
+ {
+ virtual HigherOrderInvokeExpr* createHigherOrderInvokeExpr(SemanticsVisitor* semantics) override
+ {
+ return semantics->getASTBuilder()->create<PrimalSubstituteExpr>();
+ }
+ void fillHigherOrderInvokeExpr(HigherOrderInvokeExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ {
+ resultDiffExpr->baseFunction = funcExpr;
+ auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
+ if (!baseFuncType)
+ {
+ resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
+ semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ return;
+ }
+ resultDiffExpr->type = baseFuncType;
+ if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
+ {
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
+ {
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
+ {
+ for (auto param : funcDecl->getParameters())
+ {
+ resultDiffExpr->newParameterNames.add(param->getName());
+ }
+ }
+ }
+ }
+ };
+
+ static Expr* _checkHigherOrderInvokeExpr(
SemanticsVisitor* semantics,
- DifferentiateExpr* expr,
- DifferentiateExprCheckingActions* actions)
+ HigherOrderInvokeExpr* expr,
+ HigherOrderInvokeExprCheckingActions* actions)
{
// Check/Resolve inner function declaration.
expr->baseFunction = semantics->CheckTerm(expr->baseFunction);
@@ -2354,8 +2389,8 @@ namespace Slang
nullptr,
overloadedExpr->loc,
nullptr);
- auto candidateExpr = actions->createDifferentiateExpr(semantics);
- actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr);
+ auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics);
+ actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr);
candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
@@ -2368,8 +2403,8 @@ namespace Slang
OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>();
for (auto item : overloadedExpr2->candidiateExprs)
{
- auto candidateExpr = actions->createDifferentiateExpr(semantics);
- actions->fillDifferentiateExpr(candidateExpr, semantics, item);
+ auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics);
+ actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item);
candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
@@ -2378,20 +2413,26 @@ namespace Slang
return result;
}
- actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction);
+ actions->fillHigherOrderInvokeExpr(expr, semantics, expr->baseFunction);
return expr;
}
Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
ForwardDifferentiateExprCheckingActions actions;
- return _checkDifferentiateExpr(this, expr, &actions);
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
}
Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
{
BackwardDifferentiateExprCheckingActions actions;
- return _checkDifferentiateExpr(this, expr, &actions);
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
+ }
+
+ Expr* SemanticsExprVisitor::visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
+ {
+ PrimalSubstituteExprCheckingActions actions;
+ return _checkHigherOrderInvokeExpr(this, expr, &actions);
}
Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)