diff options
| author | ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> | 2024-06-01 02:38:46 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-31 23:38:46 -0700 |
| commit | 5799281bda2f9a174b825de4058c5e8c9aa5b27f (patch) | |
| tree | a9ecfe7e9320d0722a51ba8c5c101f8ffb9fb04b /tools/slang-capability-generator/capability-generator-main.cpp | |
| parent | a5cdb574b391e8adce1ce71e1e7ab3a20ce15818 (diff) | |
Capabilities generator inclusive join and misc (#4237)
Diffstat (limited to 'tools/slang-capability-generator/capability-generator-main.cpp')
| -rw-r--r-- | tools/slang-capability-generator/capability-generator-main.cpp | 374 |
1 files changed, 316 insertions, 58 deletions
diff --git a/tools/slang-capability-generator/capability-generator-main.cpp b/tools/slang-capability-generator/capability-generator-main.cpp index c34b4c49e..a8e3e397d 100644 --- a/tools/slang-capability-generator/capability-generator-main.cpp +++ b/tools/slang-capability-generator/capability-generator-main.cpp @@ -30,6 +30,7 @@ struct CapabilityDef; struct CapabilityConjunctionExpr { List<CapabilityDef*> atoms; + SourceLoc sourceLoc; }; struct CapabilityDisjunctionExpr @@ -43,17 +44,60 @@ struct SerializedArrayView Index count; }; +struct CapabilitySharedContext +{ + CapabilityDef* ptrOfTarget = nullptr; + CapabilityDef* ptrOfStage = nullptr; +}; + +static void _removeFromOtherAtomsNotInThis(HashSet<const CapabilityDef*> thisSet, HashSet<const CapabilityDef*> otherSet, List<const CapabilityDef*> atomsToRemove) +{ + atomsToRemove.clear(); + atomsToRemove.reserve(otherSet.getCount()); + for (auto keyAtom : otherSet) + { + if (thisSet.contains(keyAtom)) + continue; + atomsToRemove.add(keyAtom); + } + + for (auto atomToRemove : atomsToRemove) + otherSet.remove(atomToRemove); +} + struct CapabilityDef : public RefObject { - Index enumValue; +public: + void operator=(const CapabilityDef& other) + { + this->name = other.name; + this->enumValue = other.enumValue; + this->expr = other.expr; + this->flavor = other.flavor; + this->rank = other.rank; + this->canonicalRepresentation = other.canonicalRepresentation; + this->serializedCanonicalRepresentation = other.serializedCanonicalRepresentation; + this->sourceLoc = other.sourceLoc; + this->keyAtomsPresent = other.keyAtomsPresent; + this->sharedContext = other.sharedContext; + } + String name; + Index enumValue; CapabilityDisjunctionExpr expr; CapabilityFlavor flavor; - int rank; + /// optional, 0 is default rank. + int rank = 0; List<List<CapabilityDef*>> canonicalRepresentation; SerializedArrayView serializedCanonicalRepresentation; + SourceLoc sourceLoc; + /// Stores key atoms a CapabilityDef refers to. + /// Shared key atoms: key atoms shared between every individual set in a canonicalRepresentation, added together. + HashSet<const CapabilityDef*> keyAtomsPresent; - CapabilityDef* getAbstractBase() + CapabilitySharedContext* sharedContext; + + CapabilityDef* getAbstractBase() const { if (flavor != CapabilityFlavor::Normal) return nullptr; @@ -65,13 +109,80 @@ struct CapabilityDef : public RefObject return nullptr; return expr.conjunctions[0].atoms[0]; } + + void fillKeyAtomsPresentInCannonicalRepresentation() + { + HashSet<const CapabilityDef*> sharedKeyAtomsInCanonicalSet_target; + HashSet<const CapabilityDef*> sharedKeyAtomsInCanonicalSet_stage; + HashSet<const CapabilityDef*> keyAtomsFound; + List<const CapabilityDef*> atomsToRemove; + for (auto& canonicalSet : canonicalRepresentation) + { + bool alreadySetTarget = false; + bool alreadySetStage = false; + sharedKeyAtomsInCanonicalSet_target.clear(); + sharedKeyAtomsInCanonicalSet_stage.clear(); + + // find key atoms all atoms in a canonical set share. + for (auto& atom : canonicalSet) + { + bool foundTarget = false; + bool foundStage = false; + for (auto otherkeyAtomsPresent : atom->keyAtomsPresent) + { + auto base = otherkeyAtomsPresent->getAbstractBase(); + // add all `target` key atoms associated with atom in canonicalSet + if (base == sharedContext->ptrOfTarget) + { + foundTarget = true; + if (!alreadySetTarget) + sharedKeyAtomsInCanonicalSet_target.add(otherkeyAtomsPresent); + } + // add all `stage` key atoms associated with atom in canonicalSet + else if (base == sharedContext->ptrOfStage) + { + foundStage = true; + if(!alreadySetTarget) + sharedKeyAtomsInCanonicalSet_stage.add(otherkeyAtomsPresent); + } + // all key atoms associated with atom + keyAtomsFound.add(otherkeyAtomsPresent); + } + + // remove all not shared key atoms + if (foundTarget) + { + alreadySetTarget = true; + _removeFromOtherAtomsNotInThis(keyAtomsFound, sharedKeyAtomsInCanonicalSet_target, atomsToRemove); + } + if (foundStage) + { + alreadySetStage = true; + _removeFromOtherAtomsNotInThis(keyAtomsFound, sharedKeyAtomsInCanonicalSet_stage, atomsToRemove); + } + keyAtomsFound.clear(); + } + + // add all shared key atoms + for (auto keyAtom : sharedKeyAtomsInCanonicalSet_target) + this->keyAtomsPresent.add(keyAtom); + for (auto keyAtom : sharedKeyAtomsInCanonicalSet_stage) + this->keyAtomsPresent.add(keyAtom); + } + if (auto base = this->getAbstractBase()) + keyAtomsPresent.add(this); + } }; struct CapabilityDefParser { - CapabilityDefParser(Lexer* lexer, DiagnosticSink* sink) + CapabilityDefParser( + Lexer* lexer, + DiagnosticSink* sink, + CapabilitySharedContext& sharedContext) : m_lexer(lexer) , m_sink(sink) + , m_sharedContext(sharedContext) { } @@ -80,6 +191,7 @@ struct CapabilityDefParser Dictionary<String, CapabilityDef*> m_mapNameToCapability; List<RefPtr<CapabilityDef>> m_defs; + CapabilitySharedContext& m_sharedContext; TokenReader m_tokenReader; @@ -126,7 +238,7 @@ struct CapabilityDefParser m_sink->diagnose(nameToken.loc, Diagnostics::undefinedIdentifier, nameToken); return SLANG_FAIL; } - if (!(advanceIf(TokenType::OpAnd) || advanceIf(TokenType::OpAdd))) + if (!(advanceIf(TokenType::OpAdd))) break; } return SLANG_OK; @@ -137,6 +249,7 @@ struct CapabilityDefParser for (;;) { CapabilityConjunctionExpr conjunction; + conjunction.sourceLoc = this->m_tokenReader.m_cursor->getLoc(); SLANG_RETURN_ON_FAIL(parseConjunction(conjunction)); expr.conjunctions.add(conjunction); if (!advanceIf(TokenType::OpBitOr)) @@ -152,6 +265,7 @@ struct CapabilityDefParser for (;;) { RefPtr<CapabilityDef> def = new CapabilityDef(); + def->sharedContext = &m_sharedContext; def->flavor = CapabilityFlavor::Normal; auto nextToken = m_tokenReader.advanceToken(); if (nextToken.getContent() == "alias") @@ -207,11 +321,21 @@ struct CapabilityDefParser } SLANG_RETURN_ON_FAIL(readToken(TokenType::Semicolon)); m_defs.add(def); - if (!m_mapNameToCapability.addIfNotExists(def->name, def)) + if (!m_mapNameToCapability.addIfNotExists(def->name, m_defs.getLast())) { m_sink->diagnose(nextToken.loc, Diagnostics::redefinition, def->name); return SLANG_FAIL; } + + //set abstract atom identifiers + if (!m_sharedContext.ptrOfTarget + && def->name.equals("target")) + m_sharedContext.ptrOfTarget = m_defs.getLast(); + else if (!m_sharedContext.ptrOfStage + && def->name.equals("stage")) + m_sharedContext.ptrOfStage = m_defs.getLast(); + + def->sourceLoc = nameToken.loc; } return SLANG_OK; } @@ -220,6 +344,24 @@ struct CapabilityDefParser struct CapabilityConjunction { HashSet<CapabilityDef*> atoms; + + String toString() const + { + bool first = true; + String result = "["; + for (auto atom : atoms) + { + if (!first) + { + result.append(" + "); + } + first = false; + result.append(atom->name); + } + result.appendChar(']'); + return result; + } + bool implies(const CapabilityConjunction& c) const { for (auto& atom : c.atoms) @@ -230,6 +372,41 @@ struct CapabilityConjunction return true; } + const CapabilityDef* getAbstractAtom(CapabilityDef* defToFilterFor) const + { + for (auto* atom : this->atoms) + { + for (auto present : atom->keyAtomsPresent) + { + auto base = present->getAbstractBase(); + if (base != defToFilterFor) + continue; + return present; + } + } + return nullptr; + } + + bool shareTargetAndStageAtom(const CapabilityConjunction& other, CapabilitySharedContext& context) + { + // shared target means thisTarget==otherTarget + // shared stage means either `nostage + ...` or `stage == stage` + + const CapabilityDef* thisTarget = this->getAbstractAtom(context.ptrOfTarget); + const CapabilityDef* otherTarget = other.getAbstractAtom(context.ptrOfTarget); + + if (thisTarget != otherTarget && thisTarget && otherTarget) + return false; + + const CapabilityDef* thisStage = this->getAbstractAtom(context.ptrOfStage); + const CapabilityDef* otherStage = other.getAbstractAtom(context.ptrOfStage); + + if (thisStage != otherStage && thisStage && otherStage) + return false; + + return true; + } + bool isImpossible() const { // Keep a map from an abstract base to the concrete atom defined in this conjunction that implements the base. @@ -263,7 +440,78 @@ struct CapabilityDisjunction { List<CapabilityConjunction> conjunctions; - void addConjunction(const CapabilityConjunction& c) + void addConjunction(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, CapabilityConjunction& c) + { + if (c.isImpossible()) + return; + bool cImpliesThis = false; + for (Index i = 0; i < conjunctions.getCount();) + { + // implied sets will be replaced + if (c.implies(conjunctions[i])) + { + cImpliesThis = true; + conjunctions.fastRemoveAt(i); + } + else + i++; + } + if (cImpliesThis) + { + conjunctions.add(_Move(c)); + return; + } + + for (Index i = 0; i < conjunctions.getCount();) + { + if (conjunctions[i].implies(c)) + { + // subset is implied, we do not need to add it. + return; + } + else + { + // validate we are not creating a disjunction of same targets + if (conjunctions[i].shareTargetAndStageAtom(c, context)) + { + if (sink) + { + sink->diagnose(sourceLoc, Diagnostics::unionWithSameKeyAtomButNotSubset, conjunctions[i].toString(), c.toString()); + sink = nullptr; + } + } + i++; + } + } + conjunctions.add(_Move(c)); + } + void removeImplied() + { + for (Index i = 0; i < conjunctions.getCount(); i++) + { + for (Index ii = 0; ii < conjunctions.getCount(); ii++) + { + if (ii == i) + continue; + + if (!conjunctions[i].implies(conjunctions[ii])) + continue; + + if(i < ii) + { + conjunctions.fastRemoveAt(ii); + } + else + { + conjunctions.removeAt(ii); + i--; + } + ii--; + } + } + } + + void inclusiveJoinConjunction(CapabilitySharedContext& context, CapabilityConjunction& c, List<CapabilityConjunction>& toAddAfter) { if (c.isImpossible()) return; @@ -274,9 +522,15 @@ struct CapabilityDisjunction } for (Index i = 0; i < conjunctions.getCount();) { - if (conjunctions[i].implies(c)) + if (conjunctions[i].shareTargetAndStageAtom(c, context)) { - conjunctions.fastRemoveAt(i); + CapabilityConjunction toAddAfterSet; + for (auto atom : conjunctions[i].atoms) + toAddAfterSet.atoms.add(atom); + for (auto atom : c.atoms) + toAddAfterSet.atoms.add(atom); + toAddAfter.add(toAddAfterSet); + return; } else { @@ -286,7 +540,7 @@ struct CapabilityDisjunction conjunctions.add(_Move(c)); } - CapabilityDisjunction joinWith(const CapabilityDisjunction& other) + CapabilityDisjunction joinWith(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, const CapabilityDisjunction& other) { if (conjunctions.getCount() == 0) { @@ -308,9 +562,14 @@ struct CapabilityDisjunction newC.atoms.add(atom); for (auto atom : thatC.atoms) newC.atoms.add(atom); - result.addConjunction(_Move(newC)); + result.addConjunction(sink, sourceLoc, context, newC); } } + + // incompatible abstract atoms + if (result.conjunctions.getCount() == 0) + sink->diagnose(sourceLoc, Diagnostics::invalidJoinInGenerator); + return result; } @@ -353,18 +612,18 @@ CapabilityDisjunction getCanonicalRepresentation(CapabilityDef* def) return result; } -CapabilityDisjunction evaluateConjunction(const List<CapabilityDef*>& atoms) +CapabilityDisjunction evaluateConjunction(DiagnosticSink* sink, SourceLoc sourceLoc, CapabilitySharedContext& context, const List<CapabilityDef*>& atoms) { CapabilityDisjunction result; - for (auto& def : atoms) + for (auto* def : atoms) { CapabilityDisjunction defCanonical = getCanonicalRepresentation(def); - result = result.joinWith(defCanonical); + result = result.joinWith(sink, sourceLoc, context, defCanonical); } return result; } -void calcCanonicalRepresentation(CapabilityDef* def, const List<CapabilityDef*>& mapEnumValueToDef) +void calcCanonicalRepresentation(DiagnosticSink* sink, CapabilityDef* def, const List<CapabilityDef*>& mapEnumValueToDef) { CapabilityDisjunction disjunction; if (def->flavor == CapabilityFlavor::Normal) @@ -376,54 +635,47 @@ void calcCanonicalRepresentation(CapabilityDef* def, const List<CapabilityDef*>& CapabilityDisjunction exprVal; for (auto& c : def->expr.conjunctions) { - CapabilityDisjunction evalD = evaluateConjunction(c.atoms); + CapabilityDisjunction evalD = evaluateConjunction(sink, c.sourceLoc, *def->sharedContext, c.atoms); + List<CapabilityConjunction> toAddAfter; for (auto& cc : evalD.conjunctions) - exprVal.addConjunction(cc); + { + exprVal.inclusiveJoinConjunction(*def->sharedContext, cc, toAddAfter); + } + for (auto& i : toAddAfter) + exprVal.conjunctions.add(i); + if (toAddAfter.getCount() > 0) + exprVal.removeImplied(); } - disjunction = disjunction.joinWith(exprVal); + disjunction = disjunction.joinWith(sink, def->sourceLoc, *def->sharedContext, exprVal); def->canonicalRepresentation = disjunction.canonicalize(); + def->fillKeyAtomsPresentInCannonicalRepresentation(); } -void calcCanonicalRepresentations(const List<RefPtr<CapabilityDef>>& defs, const List<CapabilityDef*>& mapEnumValueToDef) +void calcCanonicalRepresentations(DiagnosticSink* sink, List<RefPtr<CapabilityDef>>& defs, const List<CapabilityDef*>& mapEnumValueToDef) { for (auto def : defs) - calcCanonicalRepresentation(def, mapEnumValueToDef); -} - -const Index kUnusedEnumValue = -1; - -// Check if "def" uses a named ("name"/enumValueOfAbstract) abstract atom. If true, assign -// the enumValue of the found abstract atom (cache the unique ID) and increment counter. -bool maybeProcessConcreteAtomForAbstractCapability(CapabilityDef* def, Index& enumValueOfAbstract, const String& name, Index& counter) -{ - if (def->getAbstractBase() - && ( - (enumValueOfAbstract == def->getAbstractBase()->enumValue) - || - (enumValueOfAbstract == kUnusedEnumValue && def->getAbstractBase()->name.equals(name)) - ) - ) - { - counter++; - enumValueOfAbstract = def->getAbstractBase()->enumValue; - return true; - } - return false; + calcCanonicalRepresentation(sink, def, mapEnumValueToDef); } void outputUIntSetAsBufferValues(const String& nameOfBuffer, StringBuilder& resultBuilder, UIntSet& set) { // store UIntSet::Element as uint8_t to stay sizeof(UIntSet::Element) independent. // underlying type may change, bits stay the same. - resultBuilder << "const static CapabilityAtomSet " << nameOfBuffer << " = CapabilityAtomSet({\n"; - for (auto i : set.getBuffer()) + resultBuilder << "inline static CapabilityAtomSet generate_" << nameOfBuffer << "()\n"; + resultBuilder << "{\n"; + resultBuilder << " CapabilityAtomSet generatedSet;\n"; + + for (Index i = 0; i < set.getBuffer().getCount(); i++) { - resultBuilder << " UIntSet::Element(" << i << "U),\n"; + resultBuilder << " generatedSet.addRawElement(UIntSet::Element(" << set.getBuffer()[i] << "), " << i << ");\n"; } - resultBuilder << " 0\n});\n"; + resultBuilder << " return generatedSet;\n"; + resultBuilder << "}\n"; + + resultBuilder << "const static CapabilityAtomSet " << nameOfBuffer << " = generate_" << nameOfBuffer << "();\n"; } -SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringBuilder& sbHeader, StringBuilder& sbCpp) +SlangResult generateDefinitions(DiagnosticSink* sink, List<RefPtr<CapabilityDef>>& defs, StringBuilder& sbHeader, StringBuilder& sbCpp) { sbHeader << "enum class CapabilityAtom\n{\n"; @@ -482,9 +734,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB sbHeader << " Count\n"; sbHeader << "};\n"; - Index enumValueOfTarget = kUnusedEnumValue; Index targetCount = 0; - Index enumValueOfStage = kUnusedEnumValue; Index stageCount = 0; UIntSet anyTargetAtomSet{}; @@ -494,10 +744,16 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB for (auto def : defs) { - if (maybeProcessConcreteAtomForAbstractCapability(def.get(), enumValueOfTarget, "target", targetCount)) + if (def->getAbstractBase() == def->sharedContext->ptrOfTarget) + { + targetCount++; anyTargetAtomSet.add(def->enumValue); - else if (maybeProcessConcreteAtomForAbstractCapability(def.get(), enumValueOfStage, "stage", stageCount)) + } + else if (def->getAbstractBase() == def->sharedContext->ptrOfStage) + { + stageCount++; anyStageAtomSet.add(def->enumValue); + } } outputUIntSetAsBufferValues("kAnyTargetUIntSetBuffer", anyTargetUIntSetHash, anyTargetAtomSet); outputUIntSetAsBufferValues("kAnyStageUIntSetBuffer", anyStageUIntSetHash, anyStageAtomSet); @@ -507,8 +763,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB sbHeader << " kCapabilityStageCount = " << stageCount << ",\n"; sbHeader << "};\n\n"; - calcCanonicalRepresentations(defs, mapEnumValueToDef); - + calcCanonicalRepresentations(sink, defs, mapEnumValueToDef); List<String> capabiltiyNameArray; List<SerializedArrayView> serializedCapabilityArrays; @@ -554,7 +809,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB result.count = conjunctions.getCount(); return result; }; - for (auto& def : defs) + for (auto def : defs) { List<SerializedArrayView> conjunctions; for (auto& c : def->canonicalRepresentation) @@ -589,7 +844,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB sbCpp << "};\n"; sbCpp << "static const CapabilityAtomInfo kCapabilityNameInfos[int(CapabilityName::Count)] = {\n"; - for (auto def : mapEnumValueToDef) + for (auto* def : mapEnumValueToDef) { if (!def) { @@ -636,7 +891,7 @@ SlangResult generateDefinitions(const List<RefPtr<CapabilityDef>>& defs, StringB } -SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<CapabilityDef>>& outDefs) +SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<CapabilityDef>>& outDefs, CapabilitySharedContext& capabilitySharedContext) { auto sourceManager = sink->getSourceManager(); @@ -651,7 +906,7 @@ SlangResult parseDefFile(DiagnosticSink* sink, String inputPath, List<RefPtr<Cap namePool.setRootNamePool(&rootPool); lexer.initialize(sourceView, sink, &namePool, sourceManager->getMemoryArena()); - CapabilityDefParser parser(&lexer, sink); + CapabilityDefParser parser(&lexer, sink, capabilitySharedContext); SLANG_RETURN_ON_FAIL(parser.parseDefs()); outDefs = _Move(parser.m_defs); @@ -708,15 +963,17 @@ int main(int argc, const char* const* argv) sourceManager.initialize(nullptr, OSFileSystem::getExtSingleton()); DiagnosticSink sink(&sourceManager, nullptr); List<RefPtr<CapabilityDef>> defs; - if (SLANG_FAILED(parseDefFile(&sink, inPath, defs))) + CapabilitySharedContext capabilitySharedContext; + if (SLANG_FAILED(parseDefFile(&sink, inPath, defs, capabilitySharedContext))) { printDiagnostics(&sink); return 1; } StringBuilder sbHeader, sbCpp; - if (SLANG_FAILED(generateDefinitions(defs, sbHeader, sbCpp))) + if (SLANG_FAILED(generateDefinitions(&sink, defs, sbHeader, sbCpp))) { + printDiagnostics(&sink); return 1; } @@ -734,5 +991,6 @@ int main(int argc, const char* const* argv) printDiagnostics(&sink); return 1; } + printDiagnostics(&sink); return 0; } |
