diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 944 |
1 files changed, 928 insertions, 16 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index fe4a7d64c..a7e197d81 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -14,7 +14,7 @@ #include "slang-syntax.h" #include "slang-ast-synthesis.h" #include "slang-ast-reflect.h" - +#include "slang-ast-iterator.h" #include <limits> namespace Slang @@ -304,6 +304,416 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); }; + template<typename VisitorType> + struct SemanticsDeclReferenceVisitor + : public SemanticsDeclVisitorBase + , public StmtVisitor<VisitorType> + , public ExprVisitor<VisitorType> + , public ValVisitor<VisitorType> + , public DeclVisitor<VisitorType> + { + SemanticsDeclReferenceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + List<SourceLoc> sourceLocStack; + + struct PushSourceLocRAII + { + List<SourceLoc>& stack; + bool shouldPop = false; + PushSourceLocRAII(List<SourceLoc>& sourceLocStack, SourceLoc loc) + : stack(sourceLocStack) + { + if (loc.isValid()) + { + stack.add(loc); + shouldPop = true; + } + } + ~PushSourceLocRAII() + { + if (shouldPop) + { + stack.removeLast(); + } + } + }; + + virtual void processReferencedDecl(Decl* decl) = 0; + + void dispatchIfNotNull(Stmt* stmt) + { + if (!stmt) + return; + PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc); + return StmtVisitor<VisitorType>::dispatch(stmt); + } + void dispatchIfNotNull(Expr* expr) + { + if (!expr) + return; + PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc); + return ExprVisitor<VisitorType>::dispatch(expr); + } + void dispatchIfNotNull(Val* val) + { + if (!val) + return; + return ValVisitor<VisitorType>::dispatch(val); + } + void dispatchIfNotNull(DeclBase* val) + { + if (!val) + return; + return DeclVisitor<VisitorType>::dispatch(val); + } + // Expr Visitor + void visitExpr(Expr*) { } + void visitIndexExpr(IndexExpr* subscriptExpr) + { + for (auto arg : subscriptExpr->indexExprs) + dispatchIfNotNull(arg); + dispatchIfNotNull(subscriptExpr->baseExpression); + } + + void visitParenExpr(ParenExpr* expr) + { + dispatchIfNotNull(expr->base); + } + + void visitAssignExpr(AssignExpr* expr) + { + dispatchIfNotNull(expr->left); + dispatchIfNotNull(expr->right); + } + + void visitGenericAppExpr(GenericAppExpr* genericAppExpr) + { + dispatchIfNotNull(genericAppExpr->functionExpr); + for (auto arg : genericAppExpr->arguments) + dispatchIfNotNull(arg); + } + + void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); } + + void visitInvokeExpr(InvokeExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } + + void visitTypeCastExpr(TypeCastExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } + + void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); } + void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) + { + dispatchIfNotNull(expr->base); + } + void visitSwizzleExpr(SwizzleExpr* expr) + { + dispatchIfNotNull(expr->base); + } + void visitOverloadedExpr(OverloadedExpr*) + { + return; + } + void visitOverloadedExpr2(OverloadedExpr2*) + { + return; + } + void visitAggTypeCtorExpr(AggTypeCtorExpr*) + { + return; + } + void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) + { + dispatchIfNotNull(expr->valueArg); + } + void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); } + void visitLetExpr(LetExpr* expr) + { + dispatchIfNotNull(expr->body); + } + void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + + void visitDeclRefExpr(DeclRefExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + void visitStaticMemberExpr(StaticMemberExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + void visitInitializerListExpr(InitializerListExpr* expr) + { + for (auto arg : expr->args) + { + dispatchIfNotNull(arg); + } + } + + void visitThisExpr(ThisExpr*) + { + return; + } + + void visitThisTypeExpr(ThisTypeExpr*) { return; } + void visitAndTypeExpr(AndTypeExpr* expr) + { + dispatchIfNotNull(expr->left.type); + dispatchIfNotNull(expr->right.type); + } + void visitPointerTypeExpr(PointerTypeExpr* expr) + { + dispatchIfNotNull(expr->base.type); + } + void visitAsTypeExpr(AsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitIsTypeExpr(IsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr); + } + void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*) + { + return; + } + void visitSPIRVAsmExpr(SPIRVAsmExpr*) + { + return; + } + void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); } + void visitFuncTypeExpr(FuncTypeExpr* expr) + { + for (const auto& t : expr->parameters) + { + dispatchIfNotNull(t.type); + } + dispatchIfNotNull(expr->result.type); + } + void visitTupleTypeExpr(TupleTypeExpr* expr) + { + for (auto t : expr->members) + { + dispatchIfNotNull(t.type); + } + } + void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); } + void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) + { + dispatchIfNotNull(expr->baseFunction); + } + void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + dispatchIfNotNull(expr->innerExpr); + } + + // Stmt Visitor + + void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); } + + void visitBlockStmt(BlockStmt* stmt) + { + dispatchIfNotNull(stmt->body); + } + + void visitSeqStmt(SeqStmt* seqStmt) + { + for (auto stmt : seqStmt->stmts) + dispatchIfNotNull(stmt); + } + + void visitLabelStmt(LabelStmt* stmt) + { + dispatchIfNotNull(stmt->innerStmt); + } + + void visitBreakStmt(BreakStmt*) { return; } + + void visitContinueStmt(ContinueStmt*) { return; } + + void visitDoWhileStmt(DoWhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } + + void visitForStmt(ForStmt* stmt) + { + dispatchIfNotNull(stmt->initialStatement); + dispatchIfNotNull(stmt->predicateExpression); + dispatchIfNotNull(stmt->sideEffectExpression); + dispatchIfNotNull(stmt->statement); + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + dispatchIfNotNull(stmt->rangeBeginExpr); + dispatchIfNotNull(stmt->rangeEndExpr); + dispatchIfNotNull(stmt->body); + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + dispatchIfNotNull(stmt->condition); + dispatchIfNotNull(stmt->body); + } + + void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); } + + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + for (auto targetCase : stmt->targetCases) + dispatchIfNotNull(targetCase); + } + + void visitTargetCaseStmt(TargetCaseStmt* stmt) + { + dispatchIfNotNull(stmt->body); + } + + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; } + + void visitDefaultStmt(DefaultStmt*) { return; } + + void visitIfStmt(IfStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->positiveStatement); + dispatchIfNotNull(stmt->negativeStatement); + } + + void visitUnparsedStmt(UnparsedStmt*) { return; } + + void visitEmptyStmt(EmptyStmt*) { return; } + + void visitDiscardStmt(DiscardStmt*) { return; } + + void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + + void visitWhileStmt(WhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } + + void visitGpuForeachStmt(GpuForeachStmt*) { return; } + + void visitExpressionStmt(ExpressionStmt* stmt) + { + dispatchIfNotNull(stmt->expression); + } + + // Val Visitor + + void visitDirectDeclRef(DirectDeclRef* declRef) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(declRef)) + return; + + processReferencedDecl(declRef->getDecl()); + } + + void visitVal(Val* val) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(val)) + return; + + for (Index i = 0; i < val->getOperandCount(); i++) + { + auto& operand = val->m_operands[i]; + switch (operand.kind) + { + case ValNodeOperandKind::ValNode: + dispatchIfNotNull(val->getOperand(i)); + break; + default: + break; + } + } + return; + } + + HashSet<Val*> visitedVals; + + // Decl visitor + void visitDeclBase(DeclBase*) + {} + + void visitContainerDecl(ContainerDecl* decl) + { + for (auto m : decl->members) + { + dispatchIfNotNull(m); + } + } + + void visitFunctionDeclBase(FunctionDeclBase* decl) + { + visitContainerDecl(decl); + dispatchIfNotNull(decl->body); + } + + void visitVarDeclBase(VarDeclBase* varDecl) + { + dispatchIfNotNull(varDecl->type.type); + dispatchIfNotNull(varDecl->initExpr); + } + }; + + struct SemanticsDeclCapabilityVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclCapabilityVisitor> + { + SemanticsDeclCapabilityVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void checkVarDeclCommon(VarDeclBase* varDecl); + + void visitVarDecl(VarDecl* varDecl) + { + checkVarDeclCommon(varDecl); + } + + void visitParamDecl(ParamDecl* paramDecl) + { + checkVarDeclCommon(paramDecl); + } + + void visitFunctionDeclBase(FunctionDeclBase* funcDecl); + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + + void diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet); + }; + + /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( Decl* decl, @@ -528,7 +938,7 @@ namespace Slang } else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() ) { - sema->ensureDecl(declRef.declRefBase, DeclCheckState::Checked); + sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked); QualType qualType; qualType.type = getType(astBuilder, enumCaseDeclRef); qualType.isLeftValue = false; @@ -873,7 +1283,7 @@ namespace Slang bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) { - if (state != DeclCheckState::Checked) + if (state < DeclCheckState::DefinitionChecked) return false; // If we are in language server, we should skip checking all the function bodies // except for the module or function that the user cared about. @@ -1058,7 +1468,7 @@ namespace Slang // If we've gone down this path, then the variable // declaration is actually pretty far along in checking - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); } else { @@ -1087,7 +1497,7 @@ namespace Slang maybeInferArraySizeForVariable(varDecl); - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); } } // @@ -1306,7 +1716,7 @@ namespace Slang // We need to ensure that any variable doesn't introduce // a constant with a circular definition. // - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); _validateCircularVarDefinition(varDecl); } else @@ -1553,7 +1963,7 @@ namespace Slang assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = satisfyingType; assocTypeDef->parentDecl = aggTypeDecl; - assocTypeDef->setCheckState(DeclCheckState::Checked); + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); aggTypeDecl->members.add(assocTypeDef); } @@ -1861,7 +2271,7 @@ namespace Slang // for(auto importDecl : moduleDecl->getMembersOfType<ImportDecl>()) { - ensureDecl(importDecl, DeclCheckState::Checked); + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); } // Next, make sure all `__include` decls are processed and the referenced @@ -1873,15 +2283,15 @@ namespace Slang auto decl = fileDecl->members[i]; if (auto includeDecl = as<IncludeDecl>(decl)) { - ensureDecl(includeDecl, DeclCheckState::Checked); + ensureDecl(includeDecl, DeclCheckState::DefinitionChecked); } else if (auto implementingDecl = as<ImplementingDecl>(decl)) { - ensureDecl(implementingDecl, DeclCheckState::Checked); + ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked); } else if (auto importDecl = as<ImportDecl>(decl)) { - ensureDecl(importDecl, DeclCheckState::Checked); + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); } } }; @@ -1893,7 +2303,7 @@ namespace Slang } // The entire goal of semantic checking is to get all of the - // declarations in the module up to `DeclCheckState::Checked`. + // declarations in the module up to `DeclCheckState::DefinitionChecked`. // // The main catch is that checking one declaration A up to state M // may required that declaration B is checked up to state N. @@ -1950,7 +2360,8 @@ namespace Slang DeclCheckState::ReadyForReference, DeclCheckState::ReadyForLookup, DeclCheckState::ReadyForConformances, - DeclCheckState::Checked + DeclCheckState::DefinitionChecked, + DeclCheckState::CapabilityChecked, }; for(auto s : states) { @@ -2855,6 +3266,9 @@ namespace Slang ThisExpr*& synThis) { auto synFuncDecl = m_astBuilder->create<FuncDecl>(); + synFuncDecl->ownedScope = m_astBuilder->create<Scope>(); + synFuncDecl->ownedScope->containerDecl = synFuncDecl; + synFuncDecl->ownedScope->parent = getScope(context->parentDecl); // For now our synthesized method will use the name and source // location of the requirement we are trying to satisfy. @@ -2954,6 +3368,7 @@ namespace Slang // For a non-`static` requirement, we need a `this` parameter. // synThis = m_astBuilder->create<ThisExpr>(); + synThis->scope = synFuncDecl->ownedScope; // The type of `this` in our method will be the type for // which we are synthesizing a conformance. @@ -3314,6 +3729,9 @@ namespace Slang // the required accessor. // auto synAccessorDecl = (AccessorDecl*) m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); + synAccessorDecl->ownedScope = m_astBuilder->create<Scope>(); + synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; + synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); // Whatever the required accessor returns, that is what our synthesized accessor will return. // @@ -3359,6 +3777,7 @@ namespace Slang // a `this` expression. // ThisExpr* synThis = m_astBuilder->create<ThisExpr>(); + synThis->scope = synAccessorDecl->ownedScope; // The type of `this` in our accessor will be the type for // which we are synthesizing a conformance. @@ -5029,7 +5448,7 @@ namespace Slang // the min/max tag values, or the total number of tags, so that people don't // have to declare these as additional cases. - enumConformanceDecl->setCheckState(DeclCheckState::Checked); + enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked); } } @@ -5055,7 +5474,7 @@ namespace Slang // doing its own header checking, rather than rely on this... caseDecl->type.type = enumType; - ensureDecl(caseDecl, DeclCheckState::Checked); + ensureDecl(caseDecl, DeclCheckState::DefinitionChecked); } // For any enum case that didn't provide an explicit @@ -7569,9 +7988,13 @@ namespace Slang SemanticsDeclAttributesVisitor(shared).dispatch(decl); break; - case DeclCheckState::Checked: + case DeclCheckState::DefinitionChecked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; + + case DeclCheckState::CapabilityChecked: + SemanticsDeclCapabilityVisitor(shared).dispatch(decl); + break; } } @@ -8144,4 +8567,493 @@ namespace Slang checkDerivativeAttribute(this, decl, primalAttr); } } + + static void _propagateRequirement(SemanticsVisitor* visitor, CapabilitySet& resultCaps, SyntaxNode* userNode, SyntaxNode* referencedNode, const CapabilitySet& nodeCaps, SourceLoc referenceLoc) + { + auto referencedDecl = as<Decl>(referencedNode); + + // Ignore cyclic references. + if (referencedDecl) + { + if (referencedDecl->checkState.isBeingChecked()) + return; + + ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked); + } + + if (resultCaps.implies(nodeCaps)) + return; + + auto oldCaps = resultCaps; + bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid(); + resultCaps.join(nodeCaps); + + auto decl = as<Decl>(userNode); + + if (!isAnyInvalid && resultCaps.isInvalid()) + { + // If joining the referenced decl's requirements results an invalid capability set, + // then the decl is using things that require conflicting set of capabilities, and we should diagnose an error. + if (referencedDecl && decl) + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToUseOfDecl, + referencedDecl, + nodeCaps, + decl, + oldCaps); + } + else if (decl) + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatement, + nodeCaps, + decl, + oldCaps); + } + else + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc, + nodeCaps, + oldCaps); + } + } + if (referencedDecl && decl) + { + for (auto& capSet : nodeCaps.getExpandedAtoms()) + { + for (auto atom : capSet.getExpandedAtoms()) + { + decl->capabilityRequirementProvenance.addIfNotExists(atom, DeclReferenceWithLoc{ referencedDecl, referenceLoc }); + } + } + } + }; + + CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt); + + template<typename ProcessFunc> + struct CapabilityDeclReferenceVisitor + : public SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> + { + typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base; + + const ProcessFunc& handleReferenceFunc; + CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext const& outer) + : handleReferenceFunc(processFunc) + , SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer) + { + } + virtual void processReferencedDecl(Decl* decl) override + { + SourceLoc loc = SourceLoc(); + if (Base::sourceLocStack.getCount()) + loc = Base::sourceLocStack.getLast(); + handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc); + } + void visitDiscardStmt(DiscardStmt* stmt) + { + handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); + } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + CapabilitySet set; + for (auto targetCase : stmt->targetCases) + { + auto targetCap = CapabilitySet(CapabilityName(targetCase->capability)); + auto oldCap = targetCap; + auto bodyCap = getStatementCapabilityUsage(this, targetCase->body); + targetCap.join(bodyCap); + if (targetCap.isInvalid()) + { + Base::getSink()->diagnose(targetCase->body->loc, Diagnostics::conflictingCapabilityDueToStatement, bodyCap, "target_switch", oldCap); + } + for (auto& conjunction : targetCap.getExpandedAtoms()) + set.unionWith(conjunction); + } + set.canonicalize(); + handleReferenceFunc(stmt, set, stmt->loc); + } + }; + + template<typename ProcessFunc> + void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func) + { + CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context); + visitor.sourceLocStack.add(initialLoc); + + if (auto val = as<Val>(node)) + visitor.dispatchIfNotNull(val); + if (auto stmt = as<Stmt>(node)) + visitor.dispatchIfNotNull(stmt); + if (auto expr = as<Expr>(node)) + visitor.dispatchIfNotNull(expr); + if (auto decl = as<Decl>(node)) + visitor.dispatchIfNotNull(decl); + } + + CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt) + { + if (stmt == nullptr) + return CapabilitySet(); + + CapabilitySet inferredRequirements; + visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc); + }); + return inferredRequirements; + } + + void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl) + { + visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); + }); + visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); + }); + } + + void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) + { + for (auto member : funcDecl->members) + { + ensureDecl(member, DeclCheckState::CapabilityChecked); + _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc); + } + visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc); + }); + + // A decls's declared capability set is a transitive join of all parent declarations. + CapabilitySet declaredCaps; + for (Decl* parent = funcDecl; parent; parent = getParentDecl(parent)) + { + CapabilitySet localDeclaredCaps; + + for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>()) + { + for (auto& set : decoration->capabilitySet.getExpandedAtoms()) + localDeclaredCaps.unionWith(set); + } + declaredCaps.join(localDeclaredCaps); + } + + if (!declaredCaps.isEmpty()) + { + // If the function is an entrypoint, add the stage to declaredCaps. + if (auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>()) + { + auto stageCaps = CapabilitySet(Profile(entryPointAttr->stage).getCapabilityName()); + if (declaredCaps.isIncompatibleWith(stageCaps)) + { + getSink()->diagnose(funcDecl->loc, Diagnostics::stageIsInCompatibleWithCapabilityDefinition, funcDecl, stageCaps, declaredCaps); + } + else + { + declaredCaps.join(stageCaps); + } + } + } + + auto vis = getDeclVisibility(funcDecl); + if (declaredCaps.isEmpty()) + { + // If the user has not declared any capabilities, + // we should diagnose an error if this is a public symbol. + if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty()) + { + if (!getModuleDecl(funcDecl)->isInLegacyLanguage) + { + getSink()->diagnose(funcDecl->loc, Diagnostics::missingCapabilityRequirementOnPublicDecl, funcDecl); + } + } + } + else + { + if (vis == DeclVisibility::Public) + { + // For public decls, we need to enforce that the function + // only uses capabilities that it declares. + const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr; + if (!CapabilitySet::checkCapabilityRequirement( + declaredCaps, + funcDecl->inferredCapabilityRequirements, + failedAvailableCapabilityConjunction)) + { + diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction); + funcDecl->inferredCapabilityRequirements = declaredCaps; + } + } + else + { + // For internal decls, their inferred capability should be joined + // with the declared capabilities. + funcDecl->inferredCapabilityRequirements.join(declaredCaps); + } + } + } + + void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + // Check that the implementation of an interface requirement is not using more capabilities + // than what's declared on the interface method. + if (inheritanceDecl->witnessTable) + { + for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary) + { + if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef) + continue; + auto requirementDecl = kv.key; + auto implDecl = kv.value.getDeclRef(); + if (!implDecl) + continue; + + if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage) + break; + + ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked); + ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked); + + const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr; + if (!CapabilitySet::checkCapabilityRequirement( + requirementDecl->inferredCapabilityRequirements, + implDecl.getDecl()->inferredCapabilityRequirements, + failedAvailableCapabilityConjunction)) + { + diagnoseUndeclaredCapability(implDecl.getDecl(), Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, failedAvailableCapabilityConjunction); + } + } + } + } + + DeclVisibility getDeclVisibility(Decl* decl) + { + if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl)) + { + auto genericDecl = as<GenericDecl>(decl->parentDecl); + if (!genericDecl) + return DeclVisibility::Default; + if (genericDecl->inner) + return getDeclVisibility(genericDecl->inner); + return DeclVisibility::Default; + } + if (auto genericDecl = as<GenericDecl>(decl)) + decl = genericDecl->inner; + for (; decl; decl = getParentDecl(decl)) + { + if (as<AccessorDecl>(decl)) + continue; + if (as<EnumCaseDecl>(decl)) + continue; + break; + } + if (!decl) + return DeclVisibility::Public; + + for (auto modifier : decl->modifiers) + { + if (as<PublicModifier>(modifier)) + return DeclVisibility::Public; + else if (as<InternalModifier>(modifier)) + return DeclVisibility::Internal; + else if (as<PrivateModifier>(modifier)) + return DeclVisibility::Private; + } + + // Interface members will always have the same visibility as the interface itself. + if (auto interfaceDecl = findParentInterfaceDecl(decl)) + { + return getDeclVisibility(interfaceDecl); + } + else if (as<NamespaceDecl>(decl)) + { + return DeclVisibility::Public; + } + if (auto parentModule = getModuleDecl(decl)) + return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + + return DeclVisibility::Default; + } + + void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom) + { + HashSet<Decl*> printedDecls; + auto thisModule = getModuleDecl(decl); + Decl* declToPrint = decl; + while (declToPrint) + { + printedDecls.add(declToPrint); + if (auto provenance = declToPrint->capabilityRequirementProvenance.tryGetValue(missingAtom)) + { + sink->diagnose(provenance->referenceLoc, Diagnostics::seeUsingOf, provenance->referencedDecl); + declToPrint = provenance->referencedDecl; + if (printedDecls.contains(declToPrint)) + break; + if (declToPrint->findModifier<RequireCapabilityAttribute>()) + break; + auto moduleDecl = getModuleDecl(declToPrint); + if (thisModule != moduleDecl) + break; + if (moduleDecl && moduleDecl->isInLegacyLanguage) + continue; + if (getDeclVisibility(declToPrint) == DeclVisibility::Public) + break; + } + else + { + break; + } + } + if (declToPrint) + { + sink->diagnose(declToPrint->loc, Diagnostics::seeDefinitionOf, declToPrint); + } + } + + // Print diagnostics tracing which referenced decls are not compatible with the given atom. + void diagnoseIncompatibleAtomProvenance(SemanticsVisitor* visitor, DiagnosticSink* sink, Decl* decl, CapabilityAtom incompatibleAtom, int traceLevels = 10) + { + Decl* refDecl = nullptr; + SourceLoc loc; + while (traceLevels > 0) + { + refDecl = nullptr; + visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + if (nodeCaps.isIncompatibleWith(incompatibleAtom)) + { + if (auto referencedDecl = as<Decl>(node)) + { + refDecl = referencedDecl; + loc = refLoc; + } + else + sink->diagnose(refLoc, Diagnostics::seeDefinitionOf, "statement"); + } + }); + if (refDecl) + { + sink->diagnose(loc, Diagnostics::seeUsingOf, refDecl); + decl = refDecl; + } + else + { + break; + } + traceLevels--; + } + } + + void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet) + { + if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0) + return; + + // There are two causes for why type checking failed on failedAvailableSet. + // The first scenario is that failedAvailableSet defines a set of capabilities on a + // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have + // a function: + // [require(hlsl)] // <-- failedAvailableSet + // [require(cpp)] + // void caller() + // { + // printf(); // assume this is defined for (cpp | cuda). + // } + // In this case we should diagnose error reporting printf isn't defined on a required target. + // + // The second scenario is when the callee is using a capability that is not provided by the requirement. + // For example: + // [require(hlsl,b,c)] + // void caller() + // { + // useD(); // require capability (hlsl,d) + // } + // In this case we should report that useD() is using a capability that is not declared by caller. + // + + // Now, we detect if we are case 1. + if (decl->inferredCapabilityRequirements.isIncompatibleWith(*failedAvailableSet)) + { + // Find the most derived atom that is leading to the incompatiblity. + for (Index i = failedAvailableSet->getExpandedAtoms().getCount() - 1; i >= 0; i--) + { + auto atom = failedAvailableSet->getExpandedAtoms()[i]; + if (!isDirectChildOfAbstractAtom(atom)) + continue; + if (decl->inferredCapabilityRequirements.isIncompatibleWith(atom)) + { + getSink()->diagnose(decl->loc, Diagnostics::declHasDependenciesNotDefinedOnTarget, decl, atom); + diagnoseIncompatibleAtomProvenance(this, getSink(), decl, atom); + return; + } + } + return; + } + + // If we reach here, we are case 2. + + CapabilityConjunctionSet* matchingRequirement = &decl->inferredCapabilityRequirements.getExpandedAtoms().getFirst(); + CapabilityAtom missingAtom = matchingRequirement->getExpandedAtoms().getFirst(); + if (missingAtom == CapabilityAtom::Invalid) + return; + + if (failedAvailableSet) + { + Int maxIntersectionCount = 0; + for (auto& usedSet : decl->inferredCapabilityRequirements.getExpandedAtoms()) + { + auto intersection = usedSet.countIntersectionWith(*failedAvailableSet); + if (intersection > maxIntersectionCount) + { + matchingRequirement = &usedSet; + maxIntersectionCount = intersection; + } + } + Index pos = 0; + for (Index i = 0; i < matchingRequirement->getExpandedAtoms().getCount(); i++) + { + auto atom = matchingRequirement->getExpandedAtoms()[i]; + while (pos < failedAvailableSet->getExpandedAtoms().getCount()) + { + if (failedAvailableSet->getExpandedAtoms()[pos] < atom) + pos++; + else + break; + } + + if (pos >= failedAvailableSet->getExpandedAtoms().getCount() || + failedAvailableSet->getExpandedAtoms()[pos] != atom) + { + missingAtom = atom; + break; + } + } + + // Select the most derived atom of `missingAtom`. + for (Index i = matchingRequirement->getExpandedAtoms().getCount() - 1; i >= 0 ; i--) + { + auto atom = matchingRequirement->getExpandedAtoms()[i]; + if (CapabilityConjunctionSet(atom).implies(missingAtom)) + { + missingAtom = atom; + break; + } + } + } + + getSink()->diagnose(decl->loc, diagnosticInfo, decl, missingAtom); + + // Print provenances. + diagnoseCapabilityProvenance(getSink(), decl, missingAtom); + } + } |
