summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp169
1 files changed, 142 insertions, 27 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f568dd8df..311a5944b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -960,23 +960,13 @@ namespace Slang
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
auto checkedTerm = _CheckTerm(term);
-
- // Differentiable type checking.
- // TODO: This can be super slow.
- if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
- {
- maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
- }
-
// Differentiable type checking.
// TODO: This can be super slow.
if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>())
+ this->m_parentFunc->findModifier<DifferentiableAttribute>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
}
-
return checkedTerm;
}
@@ -1060,6 +1050,10 @@ namespace Slang
{
return overloadedExpr->base;
}
+ else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr))
+ {
+ return overloadedExpr2->base;
+ }
return nullptr;
}
@@ -2009,7 +2003,7 @@ namespace Slang
return primalType;
}
- Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType)
+ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)
{
// Resolve JVP type here.
// Note that this type checking needs to be in sync with
@@ -2035,7 +2029,7 @@ namespace Slang
return jvpType;
}
- Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType)
+ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType)
{
// Resolve backward diff type here.
// Note that this type checking needs to be in sync with
@@ -2074,30 +2068,151 @@ namespace Slang
return type;
}
- Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
+ struct DifferentiateExprCheckingActions
{
- // Check/Resolve inner function declaration.
- expr->baseFunction = CheckTerm(expr->baseFunction);
+ virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0;
+ virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0;
+ FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr)
+ {
+ if (auto funcType = as<FuncType>(funcExpr->type.type))
+ return funcType;
+ auto astBuilder = semantics->getASTBuilder();
+ if (auto declRefExpr = as<DeclRefExpr>(funcExpr))
+ {
+ if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>())
+ {
+ // Get inner function
+ DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
+ getInner(baseFuncGenericDeclRef),
+ baseFuncGenericDeclRef.substitutions);
+ auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>();
+ if (!callableDeclRef)
+ return nullptr;
+ auto funcType = getFuncType(astBuilder, callableDeclRef);
+ return funcType;
+ }
+ }
+ return nullptr;
+ }
+ };
- // For now we only support using higher order expr as callee in an invoke expr.
- // The actual type of the higher order function will be derived during resolve invoke.
- expr->type = m_astBuilder->getBottomType();
+ struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ {
+ virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ {
+ return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>();
+ }
+ void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ {
+ resultDiffExpr->baseFunction = funcExpr;
+ auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
+ if (!baseFuncType)
+ {
+ resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
+ semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ return;
+ }
+ resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType);
+ if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
+ {
+ if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ {
+ for (auto param : funcDecl.getDecl()->getParameters())
+ {
+ resultDiffExpr->newParameterNames.add(param->getName());
+ }
+ }
+ }
+ }
+ };
- return expr;
- }
+ struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions
+ {
+ virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override
+ {
+ return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>();
+ }
+ void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override
+ {
+ resultDiffExpr->baseFunction = funcExpr;
+ auto baseFuncType = getBaseFunctionType(semantics, funcExpr);
+ if (!baseFuncType)
+ {
+ resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
+ semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ }
+ resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType);
+ if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
+ {
+ if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ {
+ for (auto param : funcDecl.getDecl()->getParameters())
+ {
+ resultDiffExpr->newParameterNames.add(param->getName());
+ }
+ }
+ resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
+ }
+ }
+ };
- Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ static Expr* _checkDifferentiateExpr(
+ SemanticsVisitor* semantics,
+ DifferentiateExpr* expr,
+ DifferentiateExprCheckingActions* actions)
{
// Check/Resolve inner function declaration.
- expr->baseFunction = CheckTerm(expr->baseFunction);
+ expr->baseFunction = semantics->CheckTerm(expr->baseFunction);
- // For now we only support using higher order expr as callee in an invoke expr.
- // The actual type of the higher order function will be derived during resolve invoke.
- expr->type = m_astBuilder->getBottomType();
+ auto astBuilder = semantics->getASTBuilder();
+ // If base is overloaded expr, we want to return an overloaded expr as check result.
+ // This is done by pushing the `differentiate` operator to each item in the overloaded expr.
+ if (auto overloadedExpr = as<OverloadedExpr>(expr->baseFunction))
+ {
+ OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>();
+ for (auto item : overloadedExpr->lookupResult2)
+ {
+ auto lookupResultExpr = semantics->ConstructLookupResultExpr(item,
+ nullptr,
+ expr->loc,
+ nullptr);
+ auto candidateExpr = actions->createDifferentiateExpr(semantics);
+ actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr);
+ result->candidiateExprs.add(candidateExpr);
+ }
+ result->type.type = astBuilder->getOverloadedType();
+ return result;
+ }
+ else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction))
+ {
+ OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>();
+ for (auto item : overloadedExpr2->candidiateExprs)
+ {
+ auto candidateExpr = actions->createDifferentiateExpr(semantics);
+ actions->fillDifferentiateExpr(candidateExpr, semantics, item);
+ result->candidiateExprs.add(candidateExpr);
+ }
+ result->type.type = astBuilder->getOverloadedType();
+ return result;
+ }
+
+ actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction);
return expr;
}
+ Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
+ {
+ ForwardDifferentiateExprCheckingActions actions;
+ return _checkDifferentiateExpr(this, expr, &actions);
+ }
+
+ Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ {
+ BackwardDifferentiateExprCheckingActions actions;
+ return _checkDifferentiateExpr(this, expr, &actions);
+ }
+
Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
expr->arrayExpr = CheckTerm(expr->arrayExpr);
@@ -2923,7 +3038,7 @@ namespace Slang
// because vectors are also declaration reference types...
//
// Also note: the way this is done right now means that the ability
- // to swizzle vectors interferes with any chance of looking up
+ // to swizzle vectors interferes with any chance o<f looking up
// members via extension, for vector or scalar types.
//
// TODO: Matrix swizzles probably need to be handled at some point.