summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-06 16:30:31 -0800
committerGitHub <noreply@github.com>2024-02-06 16:30:31 -0800
commitab41d548db376c6b52869004d1b6e21b88b4c9c8 (patch)
tree61aacddad8b8c56d77cf63ab3b650fdb28bbe0e6 /source/slang/slang-check-decl.cpp
parent6365e00179179f2bc0bc25af3d51d528501498d5 (diff)
Improve Capability System (#3555)
* Improve capability system. * Update documentation. * Tuning semantics. * LSP: hierarchical diagnostics. * Fix test. * Fix test.
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp109
1 files changed, 96 insertions, 13 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a7e197d81..3882994da 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -448,6 +448,7 @@ namespace Slang
void visitDeclRefExpr(DeclRefExpr* expr)
{
+ dispatchIfNotNull(expr->type.type);
dispatchIfNotNull(expr->declRef.declRefBase);
}
void visitStaticMemberExpr(StaticMemberExpr* expr)
@@ -528,7 +529,10 @@ namespace Slang
// Stmt Visitor
- void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); }
+ void visitDeclStmt(DeclStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->decl);
+ }
void visitBlockStmt(BlockStmt* stmt)
{
@@ -687,14 +691,29 @@ namespace Slang
: public SemanticsDeclVisitorBase
, public DeclVisitor<SemanticsDeclCapabilityVisitor>
{
+ CapabilitySet m_anyPlatfromCapabilitySet;
+
SemanticsDeclCapabilityVisitor(SemanticsContext const& outer)
: SemanticsDeclVisitorBase(outer)
{}
+ CapabilitySet& getAnyPlatformCapabilitySet()
+ {
+ if (m_anyPlatfromCapabilitySet.isEmpty())
+ {
+ m_anyPlatfromCapabilitySet = CapabilitySet(CapabilityName::any_target);
+ }
+ return m_anyPlatfromCapabilitySet;
+ }
+
+ CapabilitySet getDeclaredCapabilitySet(Decl* decl);
+
+
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
-
void checkVarDeclCommon(VarDeclBase* varDecl);
+ void visitAggTypeDeclBase(AggTypeDeclBase* decl);
+ void visitNamespaceDeclBase(NamespaceDeclBase* decl);
void visitVarDecl(VarDecl* varDecl)
{
@@ -8678,6 +8697,10 @@ namespace Slang
set.canonicalize();
handleReferenceFunc(stmt, set, stmt->loc);
}
+ void visitRequireCapabilityDecl(RequireCapabilityDecl* decl)
+ {
+ handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc);
+ }
};
template<typename ProcessFunc>
@@ -8721,6 +8744,59 @@ namespace Slang
});
}
+ CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* decl)
+ {
+ // Merge a decls's declared capability set with all parent declarations.
+ // For every existing target, we want to join their requirements together.
+ // If the the parent defines additional targets, we want to add them to the disjunction set.
+ // For example:
+ // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); }
+ // The requirement for `foo` should be glsl+glsl_ext_1 | spirv.
+ //
+ CapabilitySet declaredCaps;
+ for (Decl* parent = decl; parent; parent = getParentDecl(parent))
+ {
+ CapabilitySet localDeclaredCaps;
+ bool shouldBreak = false;
+ if (!as<AggTypeDeclBase>(parent) || parent->inferredCapabilityRequirements.isEmpty())
+ {
+ for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>())
+ {
+ for (auto& set : decoration->capabilitySet.getExpandedAtoms())
+ localDeclaredCaps.unionWith(set);
+ }
+ }
+ else
+ {
+ localDeclaredCaps = parent->inferredCapabilityRequirements;
+ shouldBreak = true;
+ }
+ // Merge decl's capability declaration with the parent.
+ for (auto& localConjunction : localDeclaredCaps.getExpandedAtoms())
+ {
+ if (declaredCaps.isIncompatibleWith(localConjunction))
+ declaredCaps.unionWith(localConjunction);
+ else
+ declaredCaps.join(localDeclaredCaps);
+ }
+ // If the parent already has inferred capability requirements, we should stop now
+ // since that already covers transitive parents.
+ if (shouldBreak)
+ break;
+ }
+ return declaredCaps;
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitAggTypeDeclBase(AggTypeDeclBase* decl)
+ {
+ decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl);
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitNamespaceDeclBase(NamespaceDeclBase* decl)
+ {
+ decl->inferredCapabilityRequirements = getDeclaredCapabilitySet(decl);
+ }
+
void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
{
for (auto member : funcDecl->members)
@@ -8733,19 +8809,17 @@ namespace Slang
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc);
});
- // A decls's declared capability set is a transitive join of all parent declarations.
- CapabilitySet declaredCaps;
- for (Decl* parent = funcDecl; parent; parent = getParentDecl(parent))
+ if (!isEffectivelyStatic(funcDecl))
{
- CapabilitySet localDeclaredCaps;
-
- for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>())
+ auto parentAggTypeDecl = getParentAggTypeDecl(funcDecl);
+ if (parentAggTypeDecl)
{
- for (auto& set : decoration->capabilitySet.getExpandedAtoms())
- localDeclaredCaps.unionWith(set);
+ ensureDecl(parentAggTypeDecl, DeclCheckState::CapabilityChecked);
+ _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements, funcDecl->loc);
}
- declaredCaps.join(localDeclaredCaps);
}
+
+ auto declaredCaps = getDeclaredCapabilitySet(funcDecl);
if (!declaredCaps.isEmpty())
{
@@ -8773,7 +8847,13 @@ namespace Slang
{
if (!getModuleDecl(funcDecl)->isInLegacyLanguage)
{
- getSink()->diagnose(funcDecl->loc, Diagnostics::missingCapabilityRequirementOnPublicDecl, funcDecl);
+ if (funcDecl->inferredCapabilityRequirements != getAnyPlatformCapabilitySet())
+ {
+ getSink()->diagnose(
+ funcDecl->loc,
+ Diagnostics::missingCapabilityRequirementOnPublicDecl,
+ funcDecl, funcDecl->inferredCapabilityRequirements);
+ }
}
}
}
@@ -8924,6 +9004,7 @@ namespace Slang
{
Decl* refDecl = nullptr;
SourceLoc loc;
+ HashSet<Decl*> printedDecls;
while (traceLevels > 0)
{
refDecl = nullptr;
@@ -8940,7 +9021,9 @@ namespace Slang
sink->diagnose(refLoc, Diagnostics::seeDefinitionOf, "statement");
}
});
- if (refDecl)
+ if (!refDecl)
+ break;
+ if (printedDecls.add(refDecl))
{
sink->diagnose(loc, Diagnostics::seeUsingOf, refDecl);
decl = refDecl;