summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp944
1 files changed, 928 insertions, 16 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index fe4a7d64c..a7e197d81 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -14,7 +14,7 @@
#include "slang-syntax.h"
#include "slang-ast-synthesis.h"
#include "slang-ast-reflect.h"
-
+#include "slang-ast-iterator.h"
#include <limits>
namespace Slang
@@ -304,6 +304,416 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
};
+ template<typename VisitorType>
+ struct SemanticsDeclReferenceVisitor
+ : public SemanticsDeclVisitorBase
+ , public StmtVisitor<VisitorType>
+ , public ExprVisitor<VisitorType>
+ , public ValVisitor<VisitorType>
+ , public DeclVisitor<VisitorType>
+ {
+ SemanticsDeclReferenceVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ List<SourceLoc> sourceLocStack;
+
+ struct PushSourceLocRAII
+ {
+ List<SourceLoc>& stack;
+ bool shouldPop = false;
+ PushSourceLocRAII(List<SourceLoc>& sourceLocStack, SourceLoc loc)
+ : stack(sourceLocStack)
+ {
+ if (loc.isValid())
+ {
+ stack.add(loc);
+ shouldPop = true;
+ }
+ }
+ ~PushSourceLocRAII()
+ {
+ if (shouldPop)
+ {
+ stack.removeLast();
+ }
+ }
+ };
+
+ virtual void processReferencedDecl(Decl* decl) = 0;
+
+ void dispatchIfNotNull(Stmt* stmt)
+ {
+ if (!stmt)
+ return;
+ PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc);
+ return StmtVisitor<VisitorType>::dispatch(stmt);
+ }
+ void dispatchIfNotNull(Expr* expr)
+ {
+ if (!expr)
+ return;
+ PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc);
+ return ExprVisitor<VisitorType>::dispatch(expr);
+ }
+ void dispatchIfNotNull(Val* val)
+ {
+ if (!val)
+ return;
+ return ValVisitor<VisitorType>::dispatch(val);
+ }
+ void dispatchIfNotNull(DeclBase* val)
+ {
+ if (!val)
+ return;
+ return DeclVisitor<VisitorType>::dispatch(val);
+ }
+ // Expr Visitor
+ void visitExpr(Expr*) { }
+ void visitIndexExpr(IndexExpr* subscriptExpr)
+ {
+ for (auto arg : subscriptExpr->indexExprs)
+ dispatchIfNotNull(arg);
+ dispatchIfNotNull(subscriptExpr->baseExpression);
+ }
+
+ void visitParenExpr(ParenExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+
+ void visitAssignExpr(AssignExpr* expr)
+ {
+ dispatchIfNotNull(expr->left);
+ dispatchIfNotNull(expr->right);
+ }
+
+ void visitGenericAppExpr(GenericAppExpr* genericAppExpr)
+ {
+ dispatchIfNotNull(genericAppExpr->functionExpr);
+ for (auto arg : genericAppExpr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); }
+
+ void visitInvokeExpr(InvokeExpr* expr)
+ {
+ dispatchIfNotNull(expr->functionExpr);
+ for (auto arg : expr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitTypeCastExpr(TypeCastExpr* expr)
+ {
+ dispatchIfNotNull(expr->functionExpr);
+ for (auto arg : expr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); }
+ void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+ void visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+ void visitOverloadedExpr(OverloadedExpr*)
+ {
+ return;
+ }
+ void visitOverloadedExpr2(OverloadedExpr2*)
+ {
+ return;
+ }
+ void visitAggTypeCtorExpr(AggTypeCtorExpr*)
+ {
+ return;
+ }
+ void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->valueArg);
+ }
+ void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); }
+ void visitLetExpr(LetExpr* expr)
+ {
+ dispatchIfNotNull(expr->body);
+ }
+ void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+
+ void visitDeclRefExpr(DeclRefExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+ void visitStaticMemberExpr(StaticMemberExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+ void visitInitializerListExpr(InitializerListExpr* expr)
+ {
+ for (auto arg : expr->args)
+ {
+ dispatchIfNotNull(arg);
+ }
+ }
+
+ void visitThisExpr(ThisExpr*)
+ {
+ return;
+ }
+
+ void visitThisTypeExpr(ThisTypeExpr*) { return; }
+ void visitAndTypeExpr(AndTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->left.type);
+ dispatchIfNotNull(expr->right.type);
+ }
+ void visitPointerTypeExpr(PointerTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->base.type);
+ }
+ void visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->witnessArg);
+ }
+ void visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->witnessArg);
+ }
+ void visitMakeOptionalExpr(MakeOptionalExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->typeExpr);
+ }
+ void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*)
+ {
+ return;
+ }
+ void visitSPIRVAsmExpr(SPIRVAsmExpr*)
+ {
+ return;
+ }
+ void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); }
+ void visitFuncTypeExpr(FuncTypeExpr* expr)
+ {
+ for (const auto& t : expr->parameters)
+ {
+ dispatchIfNotNull(t.type);
+ }
+ dispatchIfNotNull(expr->result.type);
+ }
+ void visitTupleTypeExpr(TupleTypeExpr* expr)
+ {
+ for (auto t : expr->members)
+ {
+ dispatchIfNotNull(t.type);
+ }
+ }
+ void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); }
+ void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr)
+ {
+ dispatchIfNotNull(expr->baseFunction);
+ }
+ void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ dispatchIfNotNull(expr->innerExpr);
+ }
+
+ // Stmt Visitor
+
+ void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); }
+
+ void visitBlockStmt(BlockStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitSeqStmt(SeqStmt* seqStmt)
+ {
+ for (auto stmt : seqStmt->stmts)
+ dispatchIfNotNull(stmt);
+ }
+
+ void visitLabelStmt(LabelStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->innerStmt);
+ }
+
+ void visitBreakStmt(BreakStmt*) { return; }
+
+ void visitContinueStmt(ContinueStmt*) { return; }
+
+ void visitDoWhileStmt(DoWhileStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitForStmt(ForStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->initialStatement);
+ dispatchIfNotNull(stmt->predicateExpression);
+ dispatchIfNotNull(stmt->sideEffectExpression);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->rangeBeginExpr);
+ dispatchIfNotNull(stmt->rangeEndExpr);
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitSwitchStmt(SwitchStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->condition);
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); }
+
+ void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+ {
+ for (auto targetCase : stmt->targetCases)
+ dispatchIfNotNull(targetCase);
+ }
+
+ void visitTargetCaseStmt(TargetCaseStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; }
+
+ void visitDefaultStmt(DefaultStmt*) { return; }
+
+ void visitIfStmt(IfStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->positiveStatement);
+ dispatchIfNotNull(stmt->negativeStatement);
+ }
+
+ void visitUnparsedStmt(UnparsedStmt*) { return; }
+
+ void visitEmptyStmt(EmptyStmt*) { return; }
+
+ void visitDiscardStmt(DiscardStmt*) { return; }
+
+ void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); }
+
+ void visitWhileStmt(WhileStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitGpuForeachStmt(GpuForeachStmt*) { return; }
+
+ void visitExpressionStmt(ExpressionStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->expression);
+ }
+
+ // Val Visitor
+
+ void visitDirectDeclRef(DirectDeclRef* declRef)
+ {
+ // If we have already visited, return.
+ // Otherwise add it to visited set.
+ if (!visitedVals.add(declRef))
+ return;
+
+ processReferencedDecl(declRef->getDecl());
+ }
+
+ void visitVal(Val* val)
+ {
+ // If we have already visited, return.
+ // Otherwise add it to visited set.
+ if (!visitedVals.add(val))
+ return;
+
+ for (Index i = 0; i < val->getOperandCount(); i++)
+ {
+ auto& operand = val->m_operands[i];
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ValNode:
+ dispatchIfNotNull(val->getOperand(i));
+ break;
+ default:
+ break;
+ }
+ }
+ return;
+ }
+
+ HashSet<Val*> visitedVals;
+
+ // Decl visitor
+ void visitDeclBase(DeclBase*)
+ {}
+
+ void visitContainerDecl(ContainerDecl* decl)
+ {
+ for (auto m : decl->members)
+ {
+ dispatchIfNotNull(m);
+ }
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ visitContainerDecl(decl);
+ dispatchIfNotNull(decl->body);
+ }
+
+ void visitVarDeclBase(VarDeclBase* varDecl)
+ {
+ dispatchIfNotNull(varDecl->type.type);
+ dispatchIfNotNull(varDecl->initExpr);
+ }
+ };
+
+ struct SemanticsDeclCapabilityVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclCapabilityVisitor>
+ {
+ SemanticsDeclCapabilityVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void checkVarDeclCommon(VarDeclBase* varDecl);
+
+ void visitVarDecl(VarDecl* varDecl)
+ {
+ checkVarDeclCommon(varDecl);
+ }
+
+ void visitParamDecl(ParamDecl* paramDecl)
+ {
+ checkVarDeclCommon(paramDecl);
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl);
+
+ void diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet);
+ };
+
+
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
bool isEffectivelyStatic(
Decl* decl,
@@ -528,7 +938,7 @@ namespace Slang
}
else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() )
{
- sema->ensureDecl(declRef.declRefBase, DeclCheckState::Checked);
+ sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked);
QualType qualType;
qualType.type = getType(astBuilder, enumCaseDeclRef);
qualType.isLeftValue = false;
@@ -873,7 +1283,7 @@ namespace Slang
bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state)
{
- if (state != DeclCheckState::Checked)
+ if (state < DeclCheckState::DefinitionChecked)
return false;
// If we are in language server, we should skip checking all the function bodies
// except for the module or function that the user cared about.
@@ -1058,7 +1468,7 @@ namespace Slang
// If we've gone down this path, then the variable
// declaration is actually pretty far along in checking
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
else
{
@@ -1087,7 +1497,7 @@ namespace Slang
maybeInferArraySizeForVariable(varDecl);
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
}
//
@@ -1306,7 +1716,7 @@ namespace Slang
// We need to ensure that any variable doesn't introduce
// a constant with a circular definition.
//
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
_validateCircularVarDefinition(varDecl);
}
else
@@ -1553,7 +1963,7 @@ namespace Slang
assocTypeDef->nameAndLoc.name = getName("Differential");
assocTypeDef->type.type = satisfyingType;
assocTypeDef->parentDecl = aggTypeDecl;
- assocTypeDef->setCheckState(DeclCheckState::Checked);
+ assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
aggTypeDecl->members.add(assocTypeDef);
}
@@ -1861,7 +2271,7 @@ namespace Slang
//
for(auto importDecl : moduleDecl->getMembersOfType<ImportDecl>())
{
- ensureDecl(importDecl, DeclCheckState::Checked);
+ ensureDecl(importDecl, DeclCheckState::DefinitionChecked);
}
// Next, make sure all `__include` decls are processed and the referenced
@@ -1873,15 +2283,15 @@ namespace Slang
auto decl = fileDecl->members[i];
if (auto includeDecl = as<IncludeDecl>(decl))
{
- ensureDecl(includeDecl, DeclCheckState::Checked);
+ ensureDecl(includeDecl, DeclCheckState::DefinitionChecked);
}
else if (auto implementingDecl = as<ImplementingDecl>(decl))
{
- ensureDecl(implementingDecl, DeclCheckState::Checked);
+ ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked);
}
else if (auto importDecl = as<ImportDecl>(decl))
{
- ensureDecl(importDecl, DeclCheckState::Checked);
+ ensureDecl(importDecl, DeclCheckState::DefinitionChecked);
}
}
};
@@ -1893,7 +2303,7 @@ namespace Slang
}
// The entire goal of semantic checking is to get all of the
- // declarations in the module up to `DeclCheckState::Checked`.
+ // declarations in the module up to `DeclCheckState::DefinitionChecked`.
//
// The main catch is that checking one declaration A up to state M
// may required that declaration B is checked up to state N.
@@ -1950,7 +2360,8 @@ namespace Slang
DeclCheckState::ReadyForReference,
DeclCheckState::ReadyForLookup,
DeclCheckState::ReadyForConformances,
- DeclCheckState::Checked
+ DeclCheckState::DefinitionChecked,
+ DeclCheckState::CapabilityChecked,
};
for(auto s : states)
{
@@ -2855,6 +3266,9 @@ namespace Slang
ThisExpr*& synThis)
{
auto synFuncDecl = m_astBuilder->create<FuncDecl>();
+ synFuncDecl->ownedScope = m_astBuilder->create<Scope>();
+ synFuncDecl->ownedScope->containerDecl = synFuncDecl;
+ synFuncDecl->ownedScope->parent = getScope(context->parentDecl);
// For now our synthesized method will use the name and source
// location of the requirement we are trying to satisfy.
@@ -2954,6 +3368,7 @@ namespace Slang
// For a non-`static` requirement, we need a `this` parameter.
//
synThis = m_astBuilder->create<ThisExpr>();
+ synThis->scope = synFuncDecl->ownedScope;
// The type of `this` in our method will be the type for
// which we are synthesizing a conformance.
@@ -3314,6 +3729,9 @@ namespace Slang
// the required accessor.
//
auto synAccessorDecl = (AccessorDecl*) m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType);
+ synAccessorDecl->ownedScope = m_astBuilder->create<Scope>();
+ synAccessorDecl->ownedScope->containerDecl = synAccessorDecl;
+ synAccessorDecl->ownedScope->parent = getScope(context->parentDecl);
// Whatever the required accessor returns, that is what our synthesized accessor will return.
//
@@ -3359,6 +3777,7 @@ namespace Slang
// a `this` expression.
//
ThisExpr* synThis = m_astBuilder->create<ThisExpr>();
+ synThis->scope = synAccessorDecl->ownedScope;
// The type of `this` in our accessor will be the type for
// which we are synthesizing a conformance.
@@ -5029,7 +5448,7 @@ namespace Slang
// the min/max tag values, or the total number of tags, so that people don't
// have to declare these as additional cases.
- enumConformanceDecl->setCheckState(DeclCheckState::Checked);
+ enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
}
@@ -5055,7 +5474,7 @@ namespace Slang
// doing its own header checking, rather than rely on this...
caseDecl->type.type = enumType;
- ensureDecl(caseDecl, DeclCheckState::Checked);
+ ensureDecl(caseDecl, DeclCheckState::DefinitionChecked);
}
// For any enum case that didn't provide an explicit
@@ -7569,9 +7988,13 @@ namespace Slang
SemanticsDeclAttributesVisitor(shared).dispatch(decl);
break;
- case DeclCheckState::Checked:
+ case DeclCheckState::DefinitionChecked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
+
+ case DeclCheckState::CapabilityChecked:
+ SemanticsDeclCapabilityVisitor(shared).dispatch(decl);
+ break;
}
}
@@ -8144,4 +8567,493 @@ namespace Slang
checkDerivativeAttribute(this, decl, primalAttr);
}
}
+
+ static void _propagateRequirement(SemanticsVisitor* visitor, CapabilitySet& resultCaps, SyntaxNode* userNode, SyntaxNode* referencedNode, const CapabilitySet& nodeCaps, SourceLoc referenceLoc)
+ {
+ auto referencedDecl = as<Decl>(referencedNode);
+
+ // Ignore cyclic references.
+ if (referencedDecl)
+ {
+ if (referencedDecl->checkState.isBeingChecked())
+ return;
+
+ ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked);
+ }
+
+ if (resultCaps.implies(nodeCaps))
+ return;
+
+ auto oldCaps = resultCaps;
+ bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid();
+ resultCaps.join(nodeCaps);
+
+ auto decl = as<Decl>(userNode);
+
+ if (!isAnyInvalid && resultCaps.isInvalid())
+ {
+ // If joining the referenced decl's requirements results an invalid capability set,
+ // then the decl is using things that require conflicting set of capabilities, and we should diagnose an error.
+ if (referencedDecl && decl)
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToUseOfDecl,
+ referencedDecl,
+ nodeCaps,
+ decl,
+ oldCaps);
+ }
+ else if (decl)
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToStatement,
+ nodeCaps,
+ decl,
+ oldCaps);
+ }
+ else
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc,
+ nodeCaps,
+ oldCaps);
+ }
+ }
+ if (referencedDecl && decl)
+ {
+ for (auto& capSet : nodeCaps.getExpandedAtoms())
+ {
+ for (auto atom : capSet.getExpandedAtoms())
+ {
+ decl->capabilityRequirementProvenance.addIfNotExists(atom, DeclReferenceWithLoc{ referencedDecl, referenceLoc });
+ }
+ }
+ }
+ };
+
+ CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt);
+
+ template<typename ProcessFunc>
+ struct CapabilityDeclReferenceVisitor
+ : public SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>
+ {
+ typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base;
+
+ const ProcessFunc& handleReferenceFunc;
+ CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext const& outer)
+ : handleReferenceFunc(processFunc)
+ , SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer)
+ {
+ }
+ virtual void processReferencedDecl(Decl* decl) override
+ {
+ SourceLoc loc = SourceLoc();
+ if (Base::sourceLocStack.getCount())
+ loc = Base::sourceLocStack.getLast();
+ handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc);
+ }
+ void visitDiscardStmt(DiscardStmt* stmt)
+ {
+ handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc);
+ }
+ void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+ {
+ CapabilitySet set;
+ for (auto targetCase : stmt->targetCases)
+ {
+ auto targetCap = CapabilitySet(CapabilityName(targetCase->capability));
+ auto oldCap = targetCap;
+ auto bodyCap = getStatementCapabilityUsage(this, targetCase->body);
+ targetCap.join(bodyCap);
+ if (targetCap.isInvalid())
+ {
+ Base::getSink()->diagnose(targetCase->body->loc, Diagnostics::conflictingCapabilityDueToStatement, bodyCap, "target_switch", oldCap);
+ }
+ for (auto& conjunction : targetCap.getExpandedAtoms())
+ set.unionWith(conjunction);
+ }
+ set.canonicalize();
+ handleReferenceFunc(stmt, set, stmt->loc);
+ }
+ };
+
+ template<typename ProcessFunc>
+ void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func)
+ {
+ CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context);
+ visitor.sourceLocStack.add(initialLoc);
+
+ if (auto val = as<Val>(node))
+ visitor.dispatchIfNotNull(val);
+ if (auto stmt = as<Stmt>(node))
+ visitor.dispatchIfNotNull(stmt);
+ if (auto expr = as<Expr>(node))
+ visitor.dispatchIfNotNull(expr);
+ if (auto decl = as<Decl>(node))
+ visitor.dispatchIfNotNull(decl);
+ }
+
+ CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt)
+ {
+ if (stmt == nullptr)
+ return CapabilitySet();
+
+ CapabilitySet inferredRequirements;
+ visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc);
+ });
+ return inferredRequirements;
+ }
+
+ void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
+ {
+ visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
+ });
+ visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
+ });
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
+ {
+ for (auto member : funcDecl->members)
+ {
+ ensureDecl(member, DeclCheckState::CapabilityChecked);
+ _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc);
+ }
+ visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc);
+ });
+
+ // A decls's declared capability set is a transitive join of all parent declarations.
+ CapabilitySet declaredCaps;
+ for (Decl* parent = funcDecl; parent; parent = getParentDecl(parent))
+ {
+ CapabilitySet localDeclaredCaps;
+
+ for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>())
+ {
+ for (auto& set : decoration->capabilitySet.getExpandedAtoms())
+ localDeclaredCaps.unionWith(set);
+ }
+ declaredCaps.join(localDeclaredCaps);
+ }
+
+ if (!declaredCaps.isEmpty())
+ {
+ // If the function is an entrypoint, add the stage to declaredCaps.
+ if (auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>())
+ {
+ auto stageCaps = CapabilitySet(Profile(entryPointAttr->stage).getCapabilityName());
+ if (declaredCaps.isIncompatibleWith(stageCaps))
+ {
+ getSink()->diagnose(funcDecl->loc, Diagnostics::stageIsInCompatibleWithCapabilityDefinition, funcDecl, stageCaps, declaredCaps);
+ }
+ else
+ {
+ declaredCaps.join(stageCaps);
+ }
+ }
+ }
+
+ auto vis = getDeclVisibility(funcDecl);
+ if (declaredCaps.isEmpty())
+ {
+ // If the user has not declared any capabilities,
+ // we should diagnose an error if this is a public symbol.
+ if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty())
+ {
+ if (!getModuleDecl(funcDecl)->isInLegacyLanguage)
+ {
+ getSink()->diagnose(funcDecl->loc, Diagnostics::missingCapabilityRequirementOnPublicDecl, funcDecl);
+ }
+ }
+ }
+ else
+ {
+ if (vis == DeclVisibility::Public)
+ {
+ // For public decls, we need to enforce that the function
+ // only uses capabilities that it declares.
+ const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
+ if (!CapabilitySet::checkCapabilityRequirement(
+ declaredCaps,
+ funcDecl->inferredCapabilityRequirements,
+ failedAvailableCapabilityConjunction))
+ {
+ diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction);
+ funcDecl->inferredCapabilityRequirements = declaredCaps;
+ }
+ }
+ else
+ {
+ // For internal decls, their inferred capability should be joined
+ // with the declared capabilities.
+ funcDecl->inferredCapabilityRequirements.join(declaredCaps);
+ }
+ }
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ // Check that the implementation of an interface requirement is not using more capabilities
+ // than what's declared on the interface method.
+ if (inheritanceDecl->witnessTable)
+ {
+ for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary)
+ {
+ if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef)
+ continue;
+ auto requirementDecl = kv.key;
+ auto implDecl = kv.value.getDeclRef();
+ if (!implDecl)
+ continue;
+
+ if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage)
+ break;
+
+ ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked);
+ ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked);
+
+ const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
+ if (!CapabilitySet::checkCapabilityRequirement(
+ requirementDecl->inferredCapabilityRequirements,
+ implDecl.getDecl()->inferredCapabilityRequirements,
+ failedAvailableCapabilityConjunction))
+ {
+ diagnoseUndeclaredCapability(implDecl.getDecl(), Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, failedAvailableCapabilityConjunction);
+ }
+ }
+ }
+ }
+
+ DeclVisibility getDeclVisibility(Decl* decl)
+ {
+ if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl))
+ {
+ auto genericDecl = as<GenericDecl>(decl->parentDecl);
+ if (!genericDecl)
+ return DeclVisibility::Default;
+ if (genericDecl->inner)
+ return getDeclVisibility(genericDecl->inner);
+ return DeclVisibility::Default;
+ }
+ if (auto genericDecl = as<GenericDecl>(decl))
+ decl = genericDecl->inner;
+ for (; decl; decl = getParentDecl(decl))
+ {
+ if (as<AccessorDecl>(decl))
+ continue;
+ if (as<EnumCaseDecl>(decl))
+ continue;
+ break;
+ }
+ if (!decl)
+ return DeclVisibility::Public;
+
+ for (auto modifier : decl->modifiers)
+ {
+ if (as<PublicModifier>(modifier))
+ return DeclVisibility::Public;
+ else if (as<InternalModifier>(modifier))
+ return DeclVisibility::Internal;
+ else if (as<PrivateModifier>(modifier))
+ return DeclVisibility::Private;
+ }
+
+ // Interface members will always have the same visibility as the interface itself.
+ if (auto interfaceDecl = findParentInterfaceDecl(decl))
+ {
+ return getDeclVisibility(interfaceDecl);
+ }
+ else if (as<NamespaceDecl>(decl))
+ {
+ return DeclVisibility::Public;
+ }
+ if (auto parentModule = getModuleDecl(decl))
+ return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal;
+
+ return DeclVisibility::Default;
+ }
+
+ void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom)
+ {
+ HashSet<Decl*> printedDecls;
+ auto thisModule = getModuleDecl(decl);
+ Decl* declToPrint = decl;
+ while (declToPrint)
+ {
+ printedDecls.add(declToPrint);
+ if (auto provenance = declToPrint->capabilityRequirementProvenance.tryGetValue(missingAtom))
+ {
+ sink->diagnose(provenance->referenceLoc, Diagnostics::seeUsingOf, provenance->referencedDecl);
+ declToPrint = provenance->referencedDecl;
+ if (printedDecls.contains(declToPrint))
+ break;
+ if (declToPrint->findModifier<RequireCapabilityAttribute>())
+ break;
+ auto moduleDecl = getModuleDecl(declToPrint);
+ if (thisModule != moduleDecl)
+ break;
+ if (moduleDecl && moduleDecl->isInLegacyLanguage)
+ continue;
+ if (getDeclVisibility(declToPrint) == DeclVisibility::Public)
+ break;
+ }
+ else
+ {
+ break;
+ }
+ }
+ if (declToPrint)
+ {
+ sink->diagnose(declToPrint->loc, Diagnostics::seeDefinitionOf, declToPrint);
+ }
+ }
+
+ // Print diagnostics tracing which referenced decls are not compatible with the given atom.
+ void diagnoseIncompatibleAtomProvenance(SemanticsVisitor* visitor, DiagnosticSink* sink, Decl* decl, CapabilityAtom incompatibleAtom, int traceLevels = 10)
+ {
+ Decl* refDecl = nullptr;
+ SourceLoc loc;
+ while (traceLevels > 0)
+ {
+ refDecl = nullptr;
+ visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ if (nodeCaps.isIncompatibleWith(incompatibleAtom))
+ {
+ if (auto referencedDecl = as<Decl>(node))
+ {
+ refDecl = referencedDecl;
+ loc = refLoc;
+ }
+ else
+ sink->diagnose(refLoc, Diagnostics::seeDefinitionOf, "statement");
+ }
+ });
+ if (refDecl)
+ {
+ sink->diagnose(loc, Diagnostics::seeUsingOf, refDecl);
+ decl = refDecl;
+ }
+ else
+ {
+ break;
+ }
+ traceLevels--;
+ }
+ }
+
+ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet)
+ {
+ if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0)
+ return;
+
+ // There are two causes for why type checking failed on failedAvailableSet.
+ // The first scenario is that failedAvailableSet defines a set of capabilities on a
+ // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have
+ // a function:
+ // [require(hlsl)] // <-- failedAvailableSet
+ // [require(cpp)]
+ // void caller()
+ // {
+ // printf(); // assume this is defined for (cpp | cuda).
+ // }
+ // In this case we should diagnose error reporting printf isn't defined on a required target.
+ //
+ // The second scenario is when the callee is using a capability that is not provided by the requirement.
+ // For example:
+ // [require(hlsl,b,c)]
+ // void caller()
+ // {
+ // useD(); // require capability (hlsl,d)
+ // }
+ // In this case we should report that useD() is using a capability that is not declared by caller.
+ //
+
+ // Now, we detect if we are case 1.
+ if (decl->inferredCapabilityRequirements.isIncompatibleWith(*failedAvailableSet))
+ {
+ // Find the most derived atom that is leading to the incompatiblity.
+ for (Index i = failedAvailableSet->getExpandedAtoms().getCount() - 1; i >= 0; i--)
+ {
+ auto atom = failedAvailableSet->getExpandedAtoms()[i];
+ if (!isDirectChildOfAbstractAtom(atom))
+ continue;
+ if (decl->inferredCapabilityRequirements.isIncompatibleWith(atom))
+ {
+ getSink()->diagnose(decl->loc, Diagnostics::declHasDependenciesNotDefinedOnTarget, decl, atom);
+ diagnoseIncompatibleAtomProvenance(this, getSink(), decl, atom);
+ return;
+ }
+ }
+ return;
+ }
+
+ // If we reach here, we are case 2.
+
+ CapabilityConjunctionSet* matchingRequirement = &decl->inferredCapabilityRequirements.getExpandedAtoms().getFirst();
+ CapabilityAtom missingAtom = matchingRequirement->getExpandedAtoms().getFirst();
+ if (missingAtom == CapabilityAtom::Invalid)
+ return;
+
+ if (failedAvailableSet)
+ {
+ Int maxIntersectionCount = 0;
+ for (auto& usedSet : decl->inferredCapabilityRequirements.getExpandedAtoms())
+ {
+ auto intersection = usedSet.countIntersectionWith(*failedAvailableSet);
+ if (intersection > maxIntersectionCount)
+ {
+ matchingRequirement = &usedSet;
+ maxIntersectionCount = intersection;
+ }
+ }
+ Index pos = 0;
+ for (Index i = 0; i < matchingRequirement->getExpandedAtoms().getCount(); i++)
+ {
+ auto atom = matchingRequirement->getExpandedAtoms()[i];
+ while (pos < failedAvailableSet->getExpandedAtoms().getCount())
+ {
+ if (failedAvailableSet->getExpandedAtoms()[pos] < atom)
+ pos++;
+ else
+ break;
+ }
+
+ if (pos >= failedAvailableSet->getExpandedAtoms().getCount() ||
+ failedAvailableSet->getExpandedAtoms()[pos] != atom)
+ {
+ missingAtom = atom;
+ break;
+ }
+ }
+
+ // Select the most derived atom of `missingAtom`.
+ for (Index i = matchingRequirement->getExpandedAtoms().getCount() - 1; i >= 0 ; i--)
+ {
+ auto atom = matchingRequirement->getExpandedAtoms()[i];
+ if (CapabilityConjunctionSet(atom).implies(missingAtom))
+ {
+ missingAtom = atom;
+ break;
+ }
+ }
+ }
+
+ getSink()->diagnose(decl->loc, diagnosticInfo, decl, missingAtom);
+
+ // Print provenances.
+ diagnoseCapabilityProvenance(getSink(), decl, missingAtom);
+ }
+
}