summaryrefslogtreecommitdiffstats
path: root/tools/slang-capability-generator/capability-generator-main.cpp
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-06-01 02:38:46 -0400
committerGitHub <noreply@github.com>2024-05-31 23:38:46 -0700
commit5799281bda2f9a174b825de4058c5e8c9aa5b27f (patch)
treea9ecfe7e9320d0722a51ba8c5c101f8ffb9fb04b /tools/slang-capability-generator/capability-generator-main.cpp
parenta5cdb574b391e8adce1ce71e1e7ab3a20ce15818 (diff)
Capabilities generator inclusive join and misc (#4237)
Diffstat (limited to 'tools/slang-capability-generator/capability-generator-main.cpp')
-rw-r--r--tools/slang-capability-generator/capability-generator-main.cpp374
1 files changed, 316 insertions, 58 deletions
diff --git a/tools/slang-capability-generator/capability-generator-main.cpp b/tools/slang-capability-generator/capability-generator-main.cpp
index c34b4c49e..a8e3e397d 100644
--- a/tools/slang-capability-generator/capability-generator-main.cpp
+++ b/tools/slang-capability-generator/capability-generator-main.cpp
@@ -30,6 +30,7 @@ struct CapabilityDef;
struct CapabilityConjunctionExpr
{
List<CapabilityDef*> atoms;
+ SourceLoc sourceLoc;
};
struct CapabilityDisjunctionExpr
@@ -43,17 +44,60 @@ struct SerializedArrayView
Index count;
};
+struct CapabilitySharedContext
+{
+ CapabilityDef* ptrOfTarget = nullptr;
+ CapabilityDef* ptrOfStage = nullptr;
+};
+
+static void _removeFromOtherAtomsNotInThis(HashSet<const CapabilityDef*> thisSet, HashSet<const CapabilityDef*> otherSet, List<const CapabilityDef*> atomsToRemove)
+{
+ atomsToRemove.clear();
+ atomsToRemove.reserve(otherSet.getCount());
+ for (auto keyAtom : otherSet)
+ {
+ if (thisSet.contains(keyAtom))
+ continue;
+ atomsToRemove.add(keyAtom);
+ }
+
+ for (auto atomToRemove : atomsToRemove)
+ otherSet.remove(atomToRemove);
+}
+
struct CapabilityDef : public RefObject
{
- Index enumValue;
+public:
+ void operator=(const CapabilityDef& other)
+ {
+ this->name = other.name;
+ this->enumValue = other.enumValue;
+ this->expr = other.expr;
+ this->flavor = other.flavor;
+ this->rank = other.rank;
+ this->canonicalRepresentation = other.canonicalRepresentation;
+ this->serializedCanonicalRepresentation = other.serializedCanonicalRepresentation;
+ this->sourceLoc = other.sourceLoc;
+ this->keyAtomsPresent = other.keyAtomsPresent;
+ this->sharedContext = other.sharedContext;
+ }
+
String name;
+ Index enumValue;
CapabilityDisjunctionExpr expr;
CapabilityFlavor flavor;
- int rank;
+ /// optional, 0 is default rank.
+ int rank = 0;
List<List<CapabilityDef*>> canonicalRepresentation;
SerializedArrayView serializedCanonicalRepresentation;
+ SourceLoc sourceLoc;
+ /// Stores key atoms a CapabilityDef refers to.
+ /// Shared key atoms: key atoms shared between every individual set in a canonicalRepresentation, added together.
+ HashSet<const CapabilityDef*> keyAtomsPresent;
- CapabilityDef* getAbstractBase()
+ CapabilitySharedContext* sharedContext;
+
+ CapabilityDef* getAbstractBase() const
{
if (flavor != CapabilityFlavor::Normal)
return nullptr;
@@ -65,13 +109,80 @@ struct CapabilityDef : public RefObject
return nullptr;
return expr.conjunctions[0].atoms[0];
}
+
+ void fillKeyAtomsPresentInCannonicalRepresentation()
+ {
+ HashSet<const CapabilityDef*> sharedKeyAtomsInCanonicalSet_target;
+ HashSet<const CapabilityDef*> sharedKeyAtomsInCanonicalSet_stage;
+ HashSet<const CapabilityDef*> keyAtomsFound;
+ List<const CapabilityDef*> atomsToRemove;
+ for (auto& canonicalSet : canonicalRepresentation)
+ {
+ bool alreadySetTarget = false;
+ bool alreadySetStage = false;
+ sharedKeyAtomsInCanonicalSet_target.clear();
+ sharedKeyAtomsInCanonicalSet_stage.clear();
+
+ // find key atoms all atoms in a canonical set share.
+ for (auto& atom : canonicalSet)
+ {
+ bool foundTarget = false;
+ bool foundStage = false;
+ for (auto otherkeyAtomsPresent : atom->keyAtomsPresent)
+ {
+ auto base = otherkeyAtomsPresent->getAbstractBase();
+ // add all `target` key atoms associated with atom in canonicalSet
+ if (base == sharedContext->ptrOfTarget)
+ {
+ foundTarget = true;
+ if (!alreadySetTarget)
+ sharedKeyAtomsInCanonicalSet_target.add(otherkeyAtomsPresent);
+ }
+ // add all `stage` key atoms associated with atom in canonicalSet
+ else if (base == sharedContext->ptrOfStage)
+ {
+ foundStage = true;
+ if(!alreadySetTarget)
+ sharedKeyAtomsInCanonicalSet_stage.add(otherkeyAtomsPresent);
+ }
+ // all key atoms associated with atom
+ keyAtomsFound.add(otherkeyAtomsPresent);
+ }
+
+ // remove all not shared key atoms
+ if (foundTarget)
+ {
+ alreadySetTarget = true;
+ _removeFromOtherAtomsNotInThis(keyAtomsFound, sharedKeyAtomsInCanonicalSet_target, atomsToRemove);
+ }
+ if (foundStage)
+ {
+ alreadySetStage = true;
+ _removeFromOtherAtomsNotInThis(keyAtomsFound, sharedKeyAtomsInCanonicalSet_stage, atomsToRemove);
+ }
+ keyAtomsFound.clear();
+ }
+
+ // add all shared key atoms
+ for (auto keyAtom : sharedKeyAtomsInCanonicalSet_target)
+ this->keyAtomsPresent.add(keyAtom);
+ for (auto keyAtom : sharedKeyAtomsInCanonicalSet_stage)
+ this->keyAtomsPresent.add(keyAtom);
+ }
+ if (auto base = this->getAbstractBase())
+ keyAtomsPresent.add(this);
+ }
};
struct CapabilityDefParser
{
- CapabilityDefParser(Lexer* lexer, DiagnosticSink* sink)
+ CapabilityDefParser(
+ Lexer* lexer,
+ DiagnosticSink* sink,
+ CapabilitySharedContext& sharedContext)
: m_lexer(lexer)
, m_sink(sink)
+ , m_sharedContext(sharedContext)
{
}
@@ -80,6 +191,7 @@ struct CapabilityDefParser
Dictionary<String, CapabilityDef*> m_mapNameToCapability;
List<RefPtr<CapabilityDef>> m_defs;
+ CapabilitySharedContext& m_sharedContext;
TokenReader m_tokenReader;
@@ -126,7 +238,7 @@ struct CapabilityDefParser
m_sink->diagnose(nameToken.loc, Diagnostics::undefinedIdentifier, nameToken);
return SLANG_FAIL;
}
- if (!(advanceIf(TokenType::OpAnd) || advanceIf(TokenType::OpAdd)))
+ if (!(advanceIf(TokenType::OpAdd)))
break;
}
return SLANG_OK;
@@ -137,6 +249,7 @@ struct CapabilityDefParser
for (;;)
{
CapabilityConjunctionExpr conjunction;
+ conjunction.sourceLoc = this->m_tokenReader.m_cursor->getLoc();
SLANG_RETURN_ON_FAIL(parseConjunction(conjunction));
expr.conjunctions.add(conjunction);
if (!advanceIf(TokenType::OpBitOr))
@@ -152,6 +265,7 @@ struct CapabilityDefParser
for (;;)
{
RefPtr<CapabilityDef> def = new CapabilityDef();
+ def->sharedContext = &m_sharedContext;
def->flavor = CapabilityFlavor::Normal;
auto nextToken = m_tokenReader.advanceToken();
if (nextToken.getContent() == "alias")
@@ -207,11 +321,21 @@ struct CapabilityDefParser
}
SLANG_RETURN_ON_FAIL(readToken(TokenType::Semicolon));
m_defs.add(def);
- if (!m_mapNameToCapability.addIfNotExists(def->name, def))
+ if (!m_mapNameToCapability.addIfNotExists(def->name, m_defs.getLast()))
{
m_sink->diagnose(nextToken.loc, Diagnostics::redefinition, def->name);
return SLANG_FAIL;
}
+
+ //set abstract atom identifiers
+ if (!m_sharedContext.ptrOfTarget
+ && def->name.equals("target"))
+ m_sharedContext.ptrOfTarget = m_defs.getLast();
+ else if (!m_sharedContext.ptrOfStage
+ && def->name.equals("stage"))
+ m_sharedContext.ptrOfStage = m_defs.getLast();
+
+ def->sourceLoc = nameToken.loc;
}
return SLANG_OK;
}
@@ -220,6 +344,24 @@ struct CapabilityDefParser
struct CapabilityConjunction
{
HashSet<CapabilityDef*> atoms;
+
+ String toString() const
+ {
+ bool first = true;
+ String result = "[";
+ for (auto atom : atoms)
+ {
+ if (!first)
+ {
+ result.append(" + ");
+ }
+ first = false;
+ result.append(atom->name);
+ }
+ result.appendChar(']');
+ return result;
+ }
+
bool implies(const CapabilityConjunction& c) const
{
for (auto& atom : c.atoms)
@@ -230,6 +372,41 @@ struct CapabilityConjunction
return true;
}
+ const CapabilityDef* getAbstractAtom(CapabilityDef* defToFilterFor) const
+ {
+ for (auto* atom : this->atoms)
+ {
+ for (auto present : atom->keyAtomsPresent)
+ {
+ auto base = present->getAbstractBase();
+ if (base != defToFilterFor)
+ continue;
+ return present;
+ }
+ }
+ return nullptr;
+ }
+
+ bool shareTargetAndStageAtom(const CapabilityConjunction& other, CapabilitySharedContext& context)
+ {
+ // shared target means thisTarget==otherTarget
+ // shared stage means either `nostage + ...` or `stage == stage`
+
+ const CapabilityDef* thisTarget = this->getAbstractAtom(context.ptrOfTarget);
+ const CapabilityDef* otherTarget = other.getAbstractAtom(context.ptrOfTarget);
+
+ if (thisTarget != otherTarget && thisTarget && otherTarget)
+ return false;
+
+ const CapabilityDef* thisStage = this->getAbstractAtom(context.ptrOfStage);
+ const CapabilityDef* otherStage = other.getAbstractAtom(context.ptrOfStage);
+
+ if (thisStage != otherStage && thisStage && otherStage)
+ return false;
+
+ return true;
+ }
+
bool isImpossible() const
{
// Keep a map from an abstract base to the concrete atom defined in this conjunction that implements the base.
@@ -263,7 +440,78 @@ struct CapabilityDisjunction
{
List<CapabilityConjunction> conjunctions;
- void addConjunction(const CapabilityConjunction& c)
+ void addConjunction(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, CapabilityConjunction& c)
+ {
+ if (c.isImpossible())
+ return;
+ bool cImpliesThis = false;
+ for (Index i = 0; i < conjunctions.getCount();)
+ {
+ // implied sets will be replaced
+ if (c.implies(conjunctions[i]))
+ {
+ cImpliesThis = true;
+ conjunctions.fastRemoveAt(i);
+ }
+ else
+ i++;
+ }
+ if (cImpliesThis)
+ {
+ conjunctions.add(_Move(c));
+ return;
+ }
+
+ for (Index i = 0; i < conjunctions.getCount();)
+ {
+ if (conjunctions[i].implies(c))
+ {
+ // subset is implied, we do not need to add it.
+ return;
+ }
+ else
+ {
+ // validate we are not creating a disjunction of same targets
+ if (conjunctions[i].shareTargetAndStageAtom(c, context))
+ {
+ if (sink)
+ {
+ sink->diagnose(sourceLoc, Diagnostics::unionWithSameKeyAtomButNotSubset, conjunctions[i].toString(), c.toString());
+ sink = nullptr;
+ }
+ }
+ i++;
+ }
+ }
+ conjunctions.add(_Move(c));
+ }
+ void removeImplied()
+ {
+ for (Index i = 0; i < conjunctions.getCount(); i++)
+ {
+ for (Index ii = 0; ii < conjunctions.getCount(); ii++)
+ {
+ if (ii == i)
+ continue;
+
+ if (!conjunctions[i].implies(conjunctions[ii]))
+ continue;
+
+ if(i < ii)
+ {
+ conjunctions.fastRemoveAt(ii);
+ }
+ else
+ {
+ conjunctions.removeAt(ii);
+ i--;
+ }
+ ii--;
+ }
+ }
+ }
+
+ void inclusiveJoinConjunction(CapabilitySharedContext& context, CapabilityConjunction& c, List<CapabilityConjunction>& toAddAfter)
{
if (c.isImpossible())
return;
@@ -274,9 +522,15 @@ struct CapabilityDisjunction
}
for (Index i = 0; i < conjunctions.getCount();)
{
- if (conjunctions[i].implies(c))
+ if (conjunctions[i].shareTargetAndStageAtom(c, context))
{
- conjunctions.fastRemoveAt(i);
+ CapabilityConjunction toAddAfterSet;
+ for (auto atom : conjunctions[i].atoms)
+ toAddAfterSet.atoms.add(atom);
+ for (auto atom : c.atoms)
+ toAddAfterSet.atoms.add(atom);
+ toAddAfter.add(toAddAfterSet);
+ return;
}
else
{
@@ -286,7 +540,7 @@ struct CapabilityDisjunction
conjunctions.add(_Move(c));
}
- CapabilityDisjunction joinWith(const CapabilityDisjunction& other)
+ CapabilityDisjunction joinWith(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, const CapabilityDisjunction& other)
{
if (conjunctions.getCount() == 0)
{
@@ -308,9 +562,14 @@ struct CapabilityDisjunction
newC.atoms.add(atom);
for (auto atom : thatC.atoms)
newC.atoms.add(atom);
- result.addConjunction(_Move(newC));
+ result.addConjunction(sink, sourceLoc, context, newC);
}
}
+
+ // incompatible abstract atoms
+ if (result.conjunctions.getCount() == 0)
+ sink->diagnose(sourceLoc, Diagnostics::invalidJoinInGenerator);
+
return result;
}
@@ -353,18 +612,18 @@ CapabilityDisjunction getCanonicalRepresentation(CapabilityDef* def)
return result;
}
-CapabilityDisjunction evaluateConjunction(const List<CapabilityDef*>& atoms)
+CapabilityDisjunction evaluateConjunction(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, const List<CapabilityDef*>& atoms)
{
CapabilityDisjunction result;
- for (auto& def : atoms)
+ for (auto* def : atoms)
{
CapabilityDisjunction defCanonical = getCanonicalRepresentation(def);
- result = result.joinWith(defCanonical);
+ result = result.joinWith(sink, sourceLoc, context, defCanonical);
}
return result;
}
-void calcCanonicalRepresentation(CapabilityDef* def, const List<CapabilityDef*>& mapEnumValueToDef)
+void calcCanonicalRepresentation(DiagnosticSink* sink, CapabilityDef* def, const List<CapabilityDef*>& mapEnumValueToDef)
{
CapabilityDisjunction disjunction;
if (def->flavor == CapabilityFlavor::Normal)
@@ -376,54 +635,47 @@ void calcCanonicalRepresentation(CapabilityDef* def, const List<CapabilityDef*>&
CapabilityDisjunction exprVal;
for (auto& c : def->expr.conjunctions)
{
- CapabilityDisjunction evalD = evaluateConjunction(c.atoms);
+ CapabilityDisjunction evalD = evaluateConjunction(sink, c.sourceLoc, *def->sharedContext, c.atoms);
+ List<CapabilityConjunction> toAddAfter;
for (auto& cc : evalD.conjunctions)
- exprVal.addConjunction(cc);
+ {
+ exprVal.inclusiveJoinConjunction(*def->sharedContext, cc, toAddAfter);
+ }
+ for (auto& i : toAddAfter)
+ exprVal.conjunctions.add(i);
+ if (toAddAfter.getCount() > 0)
+ exprVal.removeImplied();
}
- disjunction = disjunction.joinWith(exprVal);
+ disjunction = disjunction.joinWith(sink, def->sourceLoc, *def->sharedContext, exprVal);
def->canonicalRepresentation = disjunction.canonicalize();
+ def->fillKeyAtomsPresentInCannonicalRepresentation();
}
-void calcCanonicalRepresentations(const List<RefPtr<CapabilityDef>>& defs, const List<CapabilityDef*>& mapEnumValueToDef)
+void calcCanonicalRepresentations(DiagnosticSink* sink, List<RefPtr<CapabilityDef>>& defs, const List<CapabilityDef*>& mapEnumValueToDef)
{
for (auto def : defs)
- calcCanonicalRepresentation(def, mapEnumValueToDef);
-}
-
-const Index kUnusedEnumValue = -1;
-
-// Check if "def" uses a named ("name"/enumValueOfAbstract) abstract atom. If true, assign
-// the enumValue of the found abstract atom (cache the unique ID) and increment counter.
-bool maybeProcessConcreteAtomForAbstractCapability(CapabilityDef* def, Index& enumValueOfAbstract, const String& name, Index& counter)
-{
- if (def->getAbstractBase()
- && (
- (enumValueOfAbstract == def->getAbstractBase()->enumValue)
- ||
- (enumValueOfAbstract == kUnusedEnumValue && def->getAbstractBase()->name.equals(name))
- )
- )
- {
- counter++;
- enumValueOfAbstract = def->getAbstractBase()->enumValue;
- return true;
- }
- return false;
+ calcCanonicalRepresentation(sink, def, mapEnumValueToDef);
}
void outputUIntSetAsBufferValues(const String& nameOfBuffer, StringBuilder& resultBuilder, UIntSet& set)
{
// store UIntSet::Element as uint8_t to stay sizeof(UIntSet::Element) independent.
// underlying type may change, bits stay the same.
- resultBuilder << "const static CapabilityAtomSet " << nameOfBuffer << " = CapabilityAtomSet({\n";
- for (auto i : set.getBuffer())
+ resultBuilder << "inline static CapabilityAtomSet generate_" << nameOfBuffer << "()\n";
+ resultBuilder << "{\n";
+ resultBuilder << " CapabilityAtomSet generatedSet;\n";
+
+ for (Index i = 0; i < set.getBuffer().getCount(); i++)
{
- resultBuilder << " UIntSet::Element(" << i << "U),\n";
+ resultBuilder << " generatedSet.addRawElement(UIntSet::Element(" << set.getBuffer()[i] << "), " << i << ");\n";
}
- resultBuilder << " 0\n});\n";
+ resultBuilder << " return generatedSet;\n";
+ resultBuilder << "}\n";
+
+ resultBuilder << "const static CapabilityAtomSet " << nameOfBuffer << " = generate_" << nameOfBuffer << "();\n";
}
-SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringBuilder& sbHeader, StringBuilder& sbCpp)
+SlangResult generateDefinitions(DiagnosticSink* sink, List<RefPtr<CapabilityDef>>& defs, StringBuilder& sbHeader, StringBuilder& sbCpp)
{
sbHeader << "enum class CapabilityAtom\n{\n";
@@ -482,9 +734,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
sbHeader << " Count\n";
sbHeader << "};\n";
- Index enumValueOfTarget = kUnusedEnumValue;
Index targetCount = 0;
- Index enumValueOfStage = kUnusedEnumValue;
Index stageCount = 0;
UIntSet anyTargetAtomSet{};
@@ -494,10 +744,16 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
for (auto def : defs)
{
- if (maybeProcessConcreteAtomForAbstractCapability(def.get(), enumValueOfTarget, "target", targetCount))
+ if (def->getAbstractBase() == def->sharedContext->ptrOfTarget)
+ {
+ targetCount++;
anyTargetAtomSet.add(def->enumValue);
- else if (maybeProcessConcreteAtomForAbstractCapability(def.get(), enumValueOfStage, "stage", stageCount))
+ }
+ else if (def->getAbstractBase() == def->sharedContext->ptrOfStage)
+ {
+ stageCount++;
anyStageAtomSet.add(def->enumValue);
+ }
}
outputUIntSetAsBufferValues("kAnyTargetUIntSetBuffer", anyTargetUIntSetHash, anyTargetAtomSet);
outputUIntSetAsBufferValues("kAnyStageUIntSetBuffer", anyStageUIntSetHash, anyStageAtomSet);
@@ -507,8 +763,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
sbHeader << " kCapabilityStageCount = " << stageCount << ",\n";
sbHeader << "};\n\n";
- calcCanonicalRepresentations(defs, mapEnumValueToDef);
-
+ calcCanonicalRepresentations(sink, defs, mapEnumValueToDef);
List<String> capabiltiyNameArray;
List<SerializedArrayView> serializedCapabilityArrays;
@@ -554,7 +809,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
result.count = conjunctions.getCount();
return result;
};
- for (auto& def : defs)
+ for (auto def : defs)
{
List<SerializedArrayView> conjunctions;
for (auto& c : def->canonicalRepresentation)
@@ -589,7 +844,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
sbCpp << "};\n";
sbCpp << "static const CapabilityAtomInfo kCapabilityNameInfos[int(CapabilityName::Count)] = {\n";
- for (auto def : mapEnumValueToDef)
+ for (auto* def : mapEnumValueToDef)
{
if (!def)
{
@@ -636,7 +891,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB
}
-SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<CapabilityDef>>& outDefs)
+SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<CapabilityDef>>& outDefs, CapabilitySharedContext& capabilitySharedContext)
{
auto sourceManager = sink->getSourceManager();
@@ -651,7 +906,7 @@ SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<Cap
namePool.setRootNamePool(&rootPool);
lexer.initialize(sourceView, sink, &namePool, sourceManager->getMemoryArena());
- CapabilityDefParser parser(&lexer, sink);
+ CapabilityDefParser parser(&lexer, sink, capabilitySharedContext);
SLANG_RETURN_ON_FAIL(parser.parseDefs());
outDefs = _Move(parser.m_defs);
@@ -708,15 +963,17 @@ int main(int argc, const char* const* argv)
sourceManager.initialize(nullptr, OSFileSystem::getExtSingleton());
DiagnosticSink sink(&sourceManager, nullptr);
List<RefPtr<CapabilityDef>> defs;
- if (SLANG_FAILED(parseDefFile(&sink, inPath, defs)))
+ CapabilitySharedContext capabilitySharedContext;
+ if (SLANG_FAILED(parseDefFile(&sink, inPath, defs, capabilitySharedContext)))
{
printDiagnostics(&sink);
return 1;
}
StringBuilder sbHeader, sbCpp;
- if (SLANG_FAILED(generateDefinitions(defs, sbHeader, sbCpp)))
+ if (SLANG_FAILED(generateDefinitions(&sink, defs, sbHeader, sbCpp)))
{
+ printDiagnostics(&sink);
return 1;
}
@@ -734,5 +991,6 @@ int main(int argc, const char* const* argv)
printDiagnostics(&sink);
return 1;
}
+ printDiagnostics(&sink);
return 0;
}