summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-04-19 23:18:40 -0400
committerGitHub <noreply@github.com>2024-04-19 20:18:40 -0700
commitf9bcad35562c1f08638e6d3eb397d370d7d2f8f8 (patch)
tree4e2a993689209bd5b597263922af03cb87d07c3d /source/slang/slang-check-decl.cpp
parent2da28c50d9c3699692eccde4b86d0b8d2323e55c (diff)
Initial pass to add capability declarations to stdlib intrinsics. (#3912)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp81
1 files changed, 64 insertions, 17 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index d819121bc..e8cc01ef2 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -354,6 +354,8 @@ namespace Slang
virtual void processReferencedDecl(Decl* decl) = 0;
+ virtual void processDeclModifiers(Decl* decl) = 0;
+
void dispatchIfNotNull(Stmt* stmt)
{
if (!stmt)
@@ -462,6 +464,7 @@ namespace Slang
{
dispatchIfNotNull(expr->type.type);
dispatchIfNotNull(expr->declRef.declRefBase);
+ processDeclModifiers(expr->declRef.getDecl());
}
void visitStaticMemberExpr(StaticMemberExpr* expr)
{
@@ -9813,10 +9816,11 @@ namespace Slang
typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base;
const ProcessFunc& handleReferenceFunc;
-
+ RequireCapabilityAttribute* maybeRequireCapability;
SemanticsContext& outerContext;
- CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext& outer)
+ CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, RequireCapabilityAttribute* maybeRequireCapability, SemanticsContext& outer)
: handleReferenceFunc(processFunc)
+ , maybeRequireCapability(maybeRequireCapability)
, outerContext(outer)
, SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer)
{
@@ -9828,6 +9832,11 @@ namespace Slang
loc = Base::sourceLocStack.getLast();
handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc);
}
+ virtual void processDeclModifiers(Decl* decl)
+ {
+ if (decl)
+ handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc);
+ }
void visitDiscardStmt(DiscardStmt* stmt)
{
handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc);
@@ -9835,9 +9844,42 @@ namespace Slang
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
{
CapabilitySet set;
- for (auto targetCase : stmt->targetCases)
+ auto targetCaseCount = stmt->targetCases.getCount();
+ for (Index targetCaseIndex = 0; targetCaseIndex < targetCaseCount; targetCaseIndex++)
{
- auto targetCap = CapabilitySet(CapabilityName(targetCase->capability));
+ // We may recieve a `default:` case for a `__target_switch`. If this is the case,
+ // we must resolve the target capability for a non empty set of `calling_functions_targets`:
+ // ``` default_target = calling_functions_targets-{other_case_targets} ```
+ //
+ // * `calling_functions_capability` = `requirement attribute` of the calling function; if missing
+ // we can assume it is `any_target`
+ //
+ // * `{other_case_targets}` = set of all capabilities all `case` statments target inside the `__target_switch`
+
+ // If we do not handle `default:`, the codegen will fail when trying to find a specific
+ // codegen target not handled explicitly by a `case` statment.
+ // We must also ensure the `default` case is last so we have priority to hit `case` statments and can preprocess
+ // `case` statments before the `default` case.
+ CapabilitySet targetCap;
+ if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid)
+ {
+ if (targetCaseCount - 1 != targetCaseIndex)
+ {
+ for (Index i = targetCaseIndex; i < targetCaseCount - 1; i++)
+ std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]);
+ continue;
+ }
+
+ if (!maybeRequireCapability)
+ targetCap = (CapabilitySet(CapabilityName::any_target).getTargetsThisIsMissingFromOther(set));
+ else
+ targetCap = (maybeRequireCapability->capabilitySet.getTargetsThisIsMissingFromOther(set));
+ }
+ else
+ {
+ targetCap = CapabilitySet(CapabilityName(stmt->targetCases[targetCaseIndex]->capability));
+ }
+ auto targetCase = stmt->targetCases[targetCaseIndex];
auto oldCap = targetCap;
auto bodyCap = getStatementCapabilityUsage(this, targetCase->body);
targetCap.join(bodyCap);
@@ -9851,6 +9893,7 @@ namespace Slang
set.canonicalize();
handleReferenceFunc(stmt, set, stmt->loc);
}
+
void visitRequireCapabilityDecl(RequireCapabilityDecl* decl)
{
handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc);
@@ -9858,9 +9901,9 @@ namespace Slang
};
template<typename ProcessFunc>
- void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func)
+ void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, RequireCapabilityAttribute* maybeRequireCapability, const ProcessFunc& func)
{
- CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context);
+ CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, maybeRequireCapability, context);
visitor.sourceLocStack.add(initialLoc);
if (auto val = as<Val>(node))
@@ -9879,7 +9922,7 @@ namespace Slang
return CapabilitySet();
CapabilitySet inferredRequirements;
- visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ visitReferencedDecls(*visitor, stmt, stmt->loc, nullptr, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
{
_propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc);
});
@@ -9888,11 +9931,7 @@ namespace Slang
void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
{
- visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
- {
- _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
- });
- visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, varDecl->findModifier<RequireCapabilityAttribute>(), [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
{
_propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
});
@@ -9958,7 +9997,7 @@ namespace Slang
ensureDecl(member, DeclCheckState::CapabilityChecked);
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc);
}
- visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, funcDecl->findModifier<RequireCapabilityAttribute>(), [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
{
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc);
});
@@ -9972,7 +10011,7 @@ namespace Slang
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements, funcDecl->loc);
}
}
-
+
auto declaredCaps = getDeclaredCapabilitySet(funcDecl);
if (!declaredCaps.isEmpty())
@@ -9996,12 +10035,13 @@ namespace Slang
if (declaredCaps.isEmpty())
{
// If the user has not declared any capabilities,
- // we should diagnose an error if this is a public symbol.
+ // we should diagnose a warning if any_target is not
+ // a super-set by exact atoms.
if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty())
{
if (!getModuleDecl(funcDecl)->isInLegacyLanguage)
{
- if (funcDecl->inferredCapabilityRequirements != getAnyPlatformCapabilitySet())
+ if (!funcDecl->inferredCapabilityRequirements.isExactSubset(getAnyPlatformCapabilitySet()))
{
diagnoseCapabilityErrors(
getSink(),
@@ -10019,6 +10059,9 @@ namespace Slang
{
// For public decls, we need to enforce that the function
// only uses capabilities that it declares.
+ // At a minimum we will propagate shader requirements to our
+ // function from calling children in all cases so the parent
+ // can enforce shader targets correctly and propagate to `main`
const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
if (!CapabilitySet::checkCapabilityRequirement(
declaredCaps,
@@ -10028,6 +10071,8 @@ namespace Slang
diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction);
funcDecl->inferredCapabilityRequirements = declaredCaps;
}
+ else
+ funcDecl->inferredCapabilityRequirements.simpleJoinWithSetMask(declaredCaps, CapabilityName::stage);
}
else
{
@@ -10165,7 +10210,7 @@ namespace Slang
while (traceLevels > 0)
{
refDecl = nullptr;
- visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ visitReferencedDecls(*visitor, decl, decl->loc, decl->findModifier<RequireCapabilityAttribute>(), [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
{
if (nodeCaps.isIncompatibleWith(incompatibleAtom))
{
@@ -10197,6 +10242,8 @@ namespace Slang
{
if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0)
return;
+ if(!failedAvailableSet)
+ return;
// There are two causes for why type checking failed on failedAvailableSet.
// The first scenario is that failedAvailableSet defines a set of capabilities on a