diff options
| author | ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> | 2024-04-19 23:18:40 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-19 20:18:40 -0700 |
| commit | f9bcad35562c1f08638e6d3eb397d370d7d2f8f8 (patch) | |
| tree | 4e2a993689209bd5b597263922af03cb87d07c3d /source/slang/slang-check-decl.cpp | |
| parent | 2da28c50d9c3699692eccde4b86d0b8d2323e55c (diff) | |
Initial pass to add capability declarations to stdlib intrinsics. (#3912)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 81 |
1 files changed, 64 insertions, 17 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index d819121bc..e8cc01ef2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -354,6 +354,8 @@ namespace Slang virtual void processReferencedDecl(Decl* decl) = 0; + virtual void processDeclModifiers(Decl* decl) = 0; + void dispatchIfNotNull(Stmt* stmt) { if (!stmt) @@ -462,6 +464,7 @@ namespace Slang { dispatchIfNotNull(expr->type.type); dispatchIfNotNull(expr->declRef.declRefBase); + processDeclModifiers(expr->declRef.getDecl()); } void visitStaticMemberExpr(StaticMemberExpr* expr) { @@ -9813,10 +9816,11 @@ namespace Slang typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base; const ProcessFunc& handleReferenceFunc; - + RequireCapabilityAttribute* maybeRequireCapability; SemanticsContext& outerContext; - CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext& outer) + CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, RequireCapabilityAttribute* maybeRequireCapability, SemanticsContext& outer) : handleReferenceFunc(processFunc) + , maybeRequireCapability(maybeRequireCapability) , outerContext(outer) , SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer) { @@ -9828,6 +9832,11 @@ namespace Slang loc = Base::sourceLocStack.getLast(); handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc); } + virtual void processDeclModifiers(Decl* decl) + { + if (decl) + handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc); + } void visitDiscardStmt(DiscardStmt* stmt) { handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); @@ -9835,9 +9844,42 @@ namespace Slang void visitTargetSwitchStmt(TargetSwitchStmt* stmt) { CapabilitySet set; - for (auto targetCase : stmt->targetCases) + auto targetCaseCount = stmt->targetCases.getCount(); + for (Index targetCaseIndex = 0; targetCaseIndex < targetCaseCount; targetCaseIndex++) { - auto targetCap = CapabilitySet(CapabilityName(targetCase->capability)); + // We may recieve a `default:` case for a `__target_switch`. If this is the case, + // we must resolve the target capability for a non empty set of `calling_functions_targets`: + // ``` default_target = calling_functions_targets-{other_case_targets} ``` + // + // * `calling_functions_capability` = `requirement attribute` of the calling function; if missing + // we can assume it is `any_target` + // + // * `{other_case_targets}` = set of all capabilities all `case` statments target inside the `__target_switch` + + // If we do not handle `default:`, the codegen will fail when trying to find a specific + // codegen target not handled explicitly by a `case` statment. + // We must also ensure the `default` case is last so we have priority to hit `case` statments and can preprocess + // `case` statments before the `default` case. + CapabilitySet targetCap; + if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid) + { + if (targetCaseCount - 1 != targetCaseIndex) + { + for (Index i = targetCaseIndex; i < targetCaseCount - 1; i++) + std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]); + continue; + } + + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target).getTargetsThisIsMissingFromOther(set)); + else + targetCap = (maybeRequireCapability->capabilitySet.getTargetsThisIsMissingFromOther(set)); + } + else + { + targetCap = CapabilitySet(CapabilityName(stmt->targetCases[targetCaseIndex]->capability)); + } + auto targetCase = stmt->targetCases[targetCaseIndex]; auto oldCap = targetCap; auto bodyCap = getStatementCapabilityUsage(this, targetCase->body); targetCap.join(bodyCap); @@ -9851,6 +9893,7 @@ namespace Slang set.canonicalize(); handleReferenceFunc(stmt, set, stmt->loc); } + void visitRequireCapabilityDecl(RequireCapabilityDecl* decl) { handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc); @@ -9858,9 +9901,9 @@ namespace Slang }; template<typename ProcessFunc> - void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func) + void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, RequireCapabilityAttribute* maybeRequireCapability, const ProcessFunc& func) { - CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context); + CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, maybeRequireCapability, context); visitor.sourceLocStack.add(initialLoc); if (auto val = as<Val>(node)) @@ -9879,7 +9922,7 @@ namespace Slang return CapabilitySet(); CapabilitySet inferredRequirements; - visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + visitReferencedDecls(*visitor, stmt, stmt->loc, nullptr, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) { _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc); }); @@ -9888,11 +9931,7 @@ namespace Slang void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl) { - visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) - { - _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); - }); - visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, varDecl->findModifier<RequireCapabilityAttribute>(), [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) { _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); }); @@ -9958,7 +9997,7 @@ namespace Slang ensureDecl(member, DeclCheckState::CapabilityChecked); _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc); } - visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, funcDecl->findModifier<RequireCapabilityAttribute>(), [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) { _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc); }); @@ -9972,7 +10011,7 @@ namespace Slang _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements, funcDecl->loc); } } - + auto declaredCaps = getDeclaredCapabilitySet(funcDecl); if (!declaredCaps.isEmpty()) @@ -9996,12 +10035,13 @@ namespace Slang if (declaredCaps.isEmpty()) { // If the user has not declared any capabilities, - // we should diagnose an error if this is a public symbol. + // we should diagnose a warning if any_target is not + // a super-set by exact atoms. if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty()) { if (!getModuleDecl(funcDecl)->isInLegacyLanguage) { - if (funcDecl->inferredCapabilityRequirements != getAnyPlatformCapabilitySet()) + if (!funcDecl->inferredCapabilityRequirements.isExactSubset(getAnyPlatformCapabilitySet())) { diagnoseCapabilityErrors( getSink(), @@ -10019,6 +10059,9 @@ namespace Slang { // For public decls, we need to enforce that the function // only uses capabilities that it declares. + // At a minimum we will propagate shader requirements to our + // function from calling children in all cases so the parent + // can enforce shader targets correctly and propagate to `main` const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr; if (!CapabilitySet::checkCapabilityRequirement( declaredCaps, @@ -10028,6 +10071,8 @@ namespace Slang diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction); funcDecl->inferredCapabilityRequirements = declaredCaps; } + else + funcDecl->inferredCapabilityRequirements.simpleJoinWithSetMask(declaredCaps, CapabilityName::stage); } else { @@ -10165,7 +10210,7 @@ namespace Slang while (traceLevels > 0) { refDecl = nullptr; - visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + visitReferencedDecls(*visitor, decl, decl->loc, decl->findModifier<RequireCapabilityAttribute>(), [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) { if (nodeCaps.isIncompatibleWith(incompatibleAtom)) { @@ -10197,6 +10242,8 @@ namespace Slang { if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0) return; + if(!failedAvailableSet) + return; // There are two causes for why type checking failed on failedAvailableSet. // The first scenario is that failedAvailableSet defines a set of capabilities on a |
