diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 614 |
1 files changed, 346 insertions, 268 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c42c1892..5cd7fba45 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -34,6 +34,26 @@ namespace Slang } }; + struct SemanticsDeclAttributesVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclAttributesVisitor> + { + SemanticsDeclAttributesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void visitFunctionDeclBase(FunctionDeclBase* decl); + + void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr); + + void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr); + + void checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr); + }; + struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclHeaderVisitor> @@ -258,10 +278,6 @@ namespace Slang void visitFunctionDeclBase(FunctionDeclBase* funcDecl); void visitParamDecl(ParamDecl* paramDecl); - - void checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); - - void checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4657,270 +4673,10 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } - template<typename TDerivativeAttr> - void checkDerivativeAttributeImpl( - SemanticsVisitor* visitor, - TDerivativeAttr* attr, - const List<Expr*>& imaginaryArguments) - { - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) - { - attr->funcExpr = calleeDeclRef; - return; - } - } - visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); - } - - template<typename TDerivativeAttr> - const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } - - template<> - const char* getDerivativeAttrName<ForwardDerivativeAttribute>() - { - return "ForwardDerivative"; - } - template<> - const char* getDerivativeAttrName<BackwardDerivativeAttribute>() - { - return "BackwardDerivative"; - } - - List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - for (auto param : func->getParameters()) - { - auto arg = astBuilder->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) - { - arg->type.type = pairType; - } - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) - { - List<Expr*> imaginaryArguments; - auto isOutParam = [&](ParamDecl* param) - { - return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; - }; - - for (auto param : originalFuncDecl->getParameters()) - { - auto arg = visitor->getASTBuilder()->create<VarExpr>(); - arg->declRef.decl = param; - arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; - arg->type.type = param->getType(); - arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) - { - arg->type.type = pairType; - if (isOutParam(param)) - { - // out T -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); - } - } - else - { - if (isOutParam(param)) - { - // Skip non-differentiable out params. - continue; - } - } - imaginaryArguments.add(arg); - } - if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) - { - auto arg = visitor->getASTBuilder()->create<InitializerListExpr>(); - arg->type.isLeftValue = false; - arg->type.type = diffReturnType; - arg->loc = loc; - imaginaryArguments.add(arg); - } - return imaginaryArguments; - } - - // This helper function is needed to workaround a gcc bug. - // Remove when we upgrade to a newer version of gcc. - template <typename T> - static T* _findModifier(Decl* decl) - { - return decl->findModifier<T>(); - } - - template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> - void checkDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind) - { - DeclRef<Decl> calleeDeclRef; - DeclRefExpr* calleeDeclRefExpr = nullptr; - DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); - diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr; - diffFuncExpr->loc = derivativeOfAttr->loc; - Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor); - if (!checkedDiffFuncExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); - auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) - { - auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr); - if (resolvedDiffFuncExpr) - calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction); - } - - if (!calleeDeclRefExpr) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - calleeDeclRef = calleeDeclRefExpr->declRef; - - auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); - if (!calleeFunc) - { - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); - return; - } - - if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) - { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef, - getDerivativeAttrName<TDerivativeAttr>()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); - } - derivativeOfAttr->funcExpr = calleeDeclRefExpr; - auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = - DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) - { - if (!attr->funcExpr) - return; - if (attr->funcExpr->type.type) - return; - - List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); - checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); - } - - template<typename TDerivativeAttr, typename TDerivativeOfAttr> - bool tryCheckDerivativeOfAttributeImpl( - SemanticsVisitor* visitor, - FunctionDeclBase* funcDecl, - TDerivativeOfAttr* derivativeOfAttr, - DeclAssociationKind assocKind, - const List<Expr*>& imaginaryArgsToOriginal) - { - DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); - SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); - checkDerivativeOfAttributeImpl<TDerivativeAttr>( - &subVisitor, - funcDecl, - derivativeOfAttr, - assocKind, - imaginaryArgsToOriginal); - return tempSink.getErrorCount() == 0; - } - - void SemanticsDeclBodyVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( - this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); - } - - void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl) - { - auto attr = funcDecl->findModifier<BackwardDerivativeOfAttribute>(); - if (!attr) - return; - - checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( - this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); - } - void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - // Run checking on attributes that can't be fully checked in header checking stage. - checkForwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) - checkDerivativeAttribute(this, decl, derivativeAttr); - checkBackwardDerivativeOfAttribute(decl); - if (auto derivativeAttr = decl->findModifier<BackwardDerivativeAttribute>()) - checkDerivativeAttribute(this, decl, derivativeAttr); - if (newContext.getParentDifferentiableAttribute()) { // Register additional types outside the function body first. @@ -6762,7 +6518,7 @@ namespace Slang /// Note: this function creates an empty list of candidates for the given type if /// a matching entry doesn't exist already. /// - static List<DeclAssociation>& _getDeclAssociationList( + static List<RefPtr<DeclAssociation>>& _getDeclAssociationList( Decl* decl, OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations) { @@ -6787,14 +6543,16 @@ namespace Slang void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) { auto moduleDecl = getModuleDecl(associated); - DeclAssociation assoc = {kind, associated}; + RefPtr<DeclAssociation> assoc = new DeclAssociation(); + assoc->kind = kind; + assoc->decl = associated; _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); m_associatedDeclListsBuilt = false; m_mapDeclToAssociatedDecls.Clear(); } - List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) + List<RefPtr<DeclAssociation>> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) { // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. // Consider refactoring them into the same framework. @@ -6838,6 +6596,23 @@ namespace Slang FunctionDifferentiableLevel SharedSemanticsContext::getFuncDifferentiableLevel(FunctionDeclBase* func) { + return _getFuncDifferentiableLevelImpl(func, 1); + } + + FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit) + { + if (recurseLimit > 0) + { + if (auto primalSubst = func->findModifier<PrimalSubstituteAttribute>()) + { + if (auto declRefExpr = as<DeclRefExpr>(primalSubst->funcExpr)) + { + if (auto primalSubstFunc = declRefExpr->declRef.as<FunctionDeclBase>()) + return _getFuncDifferentiableLevelImpl(primalSubstFunc, recurseLimit - 1); + } + } + } + if (func->findModifier<BackwardDifferentiableAttribute>()) return FunctionDifferentiableLevel::Backward; if (func->findModifier<BackwardDerivativeAttribute>()) @@ -6849,13 +6624,19 @@ namespace Slang for (auto assocDecl : getAssociatedDeclsForDecl(func)) { - switch (assocDecl.kind) + switch (assocDecl->kind) { case DeclAssociationKind::BackwardDerivativeFunc: return FunctionDifferentiableLevel::Backward; case DeclAssociationKind::ForwardDerivativeFunc: diffLevel = FunctionDifferentiableLevel::Forward; break; + case DeclAssociationKind::PrimalSubstituteFunc: + if (auto assocFunc = as<FunctionDeclBase>(assocDecl->decl)) + { + return _getFuncDifferentiableLevelImpl(assocFunc, recurseLimit - 1); + } + break; default: break; } @@ -6971,6 +6752,10 @@ namespace Slang SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); break; + case DeclCheckState::AttributesChecked: + SemanticsDeclAttributesVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -7058,4 +6843,297 @@ namespace Slang return val; } + + template<typename TDerivativeAttr> + void checkDerivativeAttributeImpl( + SemanticsVisitor* visitor, + TDerivativeAttr* attr, + const List<Expr*>& imaginaryArguments) + { + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + visitor->getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + template<typename TDerivativeAttr> + const char* getDerivativeAttrName() { SLANG_UNREACHABLE(""); } + + template<> + const char* getDerivativeAttrName<ForwardDerivativeAttribute>() + { + return "ForwardDerivative"; + } + template<> + const char* getDerivativeAttrName<BackwardDerivativeAttribute>() + { + return "BackwardDerivative"; + } + template<> + const char* getDerivativeAttrName<PrimalSubstituteAttribute>() + { + return "PrimalSubstitute"; + } + + List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : func->getParameters()) + { + auto arg = astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc) + { + List<Expr*> imaginaryArguments; + auto isOutParam = [&](ParamDecl* param) + { + return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; + }; + + for (auto param : originalFuncDecl->getParameters()) + { + auto arg = visitor->getASTBuilder()->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = loc; + if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) + { + arg->type.type = pairType; + if (isOutParam(param)) + { + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + if (isOutParam(param)) + { + // Skip non-differentiable out params. + continue; + } + } + imaginaryArguments.add(arg); + } + if (auto diffReturnType = visitor->tryGetDifferentialType(visitor->getASTBuilder(), originalFuncDecl->returnType.type)) + { + auto arg = visitor->getASTBuilder()->create<InitializerListExpr>(); + arg->type.isLeftValue = false; + arg->type.type = diffReturnType; + arg->loc = loc; + imaginaryArguments.add(arg); + } + return imaginaryArguments; + } + + // This helper function is needed to workaround a gcc bug. + // Remove when we upgrade to a newer version of gcc. + template <typename T> + static T* _findModifier(Decl* decl) + { + return decl->findModifier<T>(); + } + + template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr> + void checkDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind) + { + DeclRef<Decl> calleeDeclRef; + DeclRefExpr* calleeDeclRefExpr = nullptr; + HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); + higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; + higherOrderFuncExpr->loc = derivativeOfAttr->loc; + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor); + if (!checkedHigherOrderFuncExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc); + auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedHigherOrderFuncExpr, imaginaryArgs); + SemanticsContext::ExprLocalScope scope; + auto ctx = visitor->withExprLocalScope(&scope); + auto subVisitor = SemanticsVisitor(ctx); + auto resolved = subVisitor.ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr); + if (resolvedFuncExpr) + calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction); + } + + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + calleeDeclRef = calleeDeclRefExpr->declRef; + + auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); + if (!calleeFunc) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + + if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName<TDerivativeAttr>()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + } + + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc); + checkDerivativeAttributeImpl(visitor, attr, imaginaryArguments); + } + + template<typename TDerivativeAttr, typename TDerivativeOfAttr> + bool tryCheckDerivativeOfAttributeImpl( + SemanticsVisitor* visitor, + FunctionDeclBase* funcDecl, + TDerivativeOfAttr* derivativeOfAttr, + DeclAssociationKind assocKind, + const List<Expr*>& imaginaryArgsToOriginal) + { + DiagnosticSink tempSink(visitor->getSourceManager(), nullptr); + SemanticsVisitor subVisitor(visitor->withSink(&tempSink)); + checkDerivativeOfAttributeImpl<TDerivativeAttr>( + &subVisitor, + funcDecl, + derivativeOfAttr, + assocKind, + imaginaryArgsToOriginal); + return tempSink.getErrorCount() == 0; + } + + void SemanticsDeclAttributesVisitor::checkForwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl, BackwardDerivativeOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>( + this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc); + } + + void SemanticsDeclAttributesVisitor::checkPrimalSubstituteOfAttribute(FunctionDeclBase* funcDecl, PrimalSubstituteOfAttribute* attr) + { + checkDerivativeOfAttributeImpl<PrimalSubstituteAttribute, PrimalSubstituteExpr>( + this, funcDecl, attr, DeclAssociationKind::PrimalSubstituteFunc); + } + + void SemanticsDeclAttributesVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) + { + // Run checking on attributes that can't be fully checked in header checking stage. + for (auto attr : decl->modifiers) + { + if (auto fwdDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr)) + checkForwardDerivativeOfAttribute(decl, fwdDerivativeOfAttr); + else if (auto bwdDerivativeOfAttr = as<BackwardDerivativeOfAttribute>(attr)) + checkBackwardDerivativeOfAttribute(decl, bwdDerivativeOfAttr); + else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr)) + checkPrimalSubstituteOfAttribute(decl, primalOfAttr); + else if (auto fwdDerivativeAttr = as<ForwardDerivativeAttribute>(attr)) + checkDerivativeAttribute(this, decl, fwdDerivativeAttr); + else if (auto bwdDerivativeAttr = as<BackwardDerivativeAttribute>(attr)) + checkDerivativeAttribute(this, decl, bwdDerivativeAttr); + else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr)) + checkDerivativeAttribute(this, decl, primalAttr); + } + } } |
