diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 169 |
1 files changed, 142 insertions, 27 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f568dd8df..311a5944b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -960,23 +960,13 @@ namespace Slang Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); - - // Differentiable type checking. - // TODO: This can be super slow. - if (this->m_parentFunc && - this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) - { - maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - } - // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && - this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>()) + this->m_parentFunc->findModifier<DifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } - return checkedTerm; } @@ -1060,6 +1050,10 @@ namespace Slang { return overloadedExpr->base; } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr)) + { + return overloadedExpr2->base; + } return nullptr; } @@ -2009,7 +2003,7 @@ namespace Slang return primalType; } - Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType) + Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) { // Resolve JVP type here. // Note that this type checking needs to be in sync with @@ -2035,7 +2029,7 @@ namespace Slang return jvpType; } - Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) { // Resolve backward diff type here. // Note that this type checking needs to be in sync with @@ -2074,30 +2068,151 @@ namespace Slang return type; } - Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) + struct DifferentiateExprCheckingActions { - // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0; + virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; + FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) + { + if (auto funcType = as<FuncType>(funcExpr->type.type)) + return funcType; + auto astBuilder = semantics->getASTBuilder(); + if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) + { + if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>()) + { + // Get inner function + DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( + getInner(baseFuncGenericDeclRef), + baseFuncGenericDeclRef.substitutions); + auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>(); + if (!callableDeclRef) + return nullptr; + auto funcType = getFuncType(astBuilder, callableDeclRef); + return funcType; + } + } + return nullptr; + } + }; - // For now we only support using higher order expr as callee in an invoke expr. - // The actual type of the higher order function will be derived during resolve invoke. - expr->type = m_astBuilder->getBottomType(); + struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + { + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>(); + } + void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; + } + resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + { + for (auto param : funcDecl.getDecl()->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + } + } + }; - return expr; - } + struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + { + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>(); + } + void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + } + resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + { + for (auto param : funcDecl.getDecl()->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); + } + } + }; - Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + static Expr* _checkDifferentiateExpr( + SemanticsVisitor* semantics, + DifferentiateExpr* expr, + DifferentiateExprCheckingActions* actions) { // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + expr->baseFunction = semantics->CheckTerm(expr->baseFunction); - // For now we only support using higher order expr as callee in an invoke expr. - // The actual type of the higher order function will be derived during resolve invoke. - expr->type = m_astBuilder->getBottomType(); + auto astBuilder = semantics->getASTBuilder(); + // If base is overloaded expr, we want to return an overloaded expr as check result. + // This is done by pushing the `differentiate` operator to each item in the overloaded expr. + if (auto overloadedExpr = as<OverloadedExpr>(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); + for (auto item : overloadedExpr->lookupResult2) + { + auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, + nullptr, + expr->loc, + nullptr); + auto candidateExpr = actions->createDifferentiateExpr(semantics); + actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + result->candidiateExprs.add(candidateExpr); + } + result->type.type = astBuilder->getOverloadedType(); + return result; + } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); + for (auto item : overloadedExpr2->candidiateExprs) + { + auto candidateExpr = actions->createDifferentiateExpr(semantics); + actions->fillDifferentiateExpr(candidateExpr, semantics, item); + result->candidiateExprs.add(candidateExpr); + } + result->type.type = astBuilder->getOverloadedType(); + return result; + } + + actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction); return expr; } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) + { + ForwardDifferentiateExprCheckingActions actions; + return _checkDifferentiateExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + BackwardDifferentiateExprCheckingActions actions; + return _checkDifferentiateExpr(this, expr, &actions); + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); @@ -2923,7 +3038,7 @@ namespace Slang // because vectors are also declaration reference types... // // Also note: the way this is done right now means that the ability - // to swizzle vectors interferes with any chance of looking up + // to swizzle vectors interferes with any chance o<f looking up // members via extension, for vector or scalar types. // // TODO: Matrix swizzles probably need to be handled at some point. |
