From 508dc3a95de50de4a4d07d0a72a18e40d55b0e2e Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Tue, 29 Aug 2023 06:05:26 +0800 Subject: Allow bitwise or expressions and numeric literals in spirv_asm blocks (#3157) * Add -spirv-core-grammar option to load alternate spirv defs Also embed a version to use by default * Use perfect hash for spv op lookup * Neaten perfect hash embedding * Refactor spirv grammar lookup in preparation for more kinds of lookups * Load spirv capability list from spec * Add all SPIR-V enums to lookup table * regenerate vs projects * appease msvc * Use string slices for spir-v core grammar lookups * wiggle * comment * Add OpInfo for spv ops * regenerate vs projects * Embed op names * Add min/max operand counts and enum categories to spirv info * neaten * Operand kinds for spirv ops * Store and embed all information relating to spirv enums and qualifiers * Use SPIR-V spec to position instructions in spirv_asm blocks * Neaten spir-v info embedding * Neaten perfect hash embedding * Add assignment syntax to spirv_asm snippets * Better errors for spirv_asm parser * Add warning for too many operands in spirv asm * squash warnings * neaten * test wiggle * Lookup enums for spirv * Put OpCapability and OpExtension in the correct place for spirv_asm blocks * Tests for OpCapability and OpExtension * ci wiggle * Add expected failure * Allow raising immediate values to constant ids where necessary in spirv_asm blocks * Allow bitwise or expressions and numeric literals in spirv_asm blocks * test numeric literals * Fix memory issues. * fix. 
--------- Co-authored-by: Yong He --- .../lookup-generator-main.cpp | 222 +--------- .../spirv-embed-generator-main.cpp | 471 +++++++++++++++++++++ 2 files changed, 490 insertions(+), 203 deletions(-) create mode 100644 tools/slang-spirv-embed-generator/spirv-embed-generator-main.cpp (limited to 'tools') diff --git a/tools/slang-lookup-generator/lookup-generator-main.cpp b/tools/slang-lookup-generator/lookup-generator-main.cpp index f64e2a78f..a6f43c102 100644 --- a/tools/slang-lookup-generator/lookup-generator-main.cpp +++ b/tools/slang-lookup-generator/lookup-generator-main.cpp @@ -4,6 +4,7 @@ #include "../../source/compiler-core/slang-json-parser.h" #include "../../source/compiler-core/slang-json-value.h" #include "../../source/compiler-core/slang-lexer.h" +#include "../../source/compiler-core/slang-perfect-hash.h" #include "../../source/core/slang-io.h" #include "../../source/core/slang-secure-crt.h" #include "../../source/core/slang-string-util.h" @@ -68,147 +69,13 @@ static List extractOpNames(UnownedStringSlice& error, const JSONValue& v return opnames; } -struct HashParams -{ - List saltTable; - List destTable; -}; - -enum HashFindResult { - Success, - NonUniqueKeys, - UnavoidableHashCollision, -}; - -// Implemented according to "Hash, displace, and compress" -// https://cmph.sourceforge.net/papers/esa09.pdf -static HashFindResult minimalPerfectHash(const List& ss, HashParams& hashParams) -{ - // Check for uniqueness - for (Index i = 0; i < ss.getCount(); ++i) - { - for (Index j = i + 1; j < ss.getCount(); ++j) - { - if (ss[i] == ss[j]) - { - return NonUniqueKeys; - } - } - } - - SLANG_ASSERT(UIndex(ss.getCount()) < std::numeric_limits::max()); - const UInt32 nBuckets = UInt32(ss.getCount()); - List> initialBuckets; - initialBuckets.setCount(nBuckets); - - const auto hash = [&](const String& s, const HashCode64 salt = 0) -> UInt32 - { - // - // The current getStableHashCode is susceptible to patterns of - // collisions causing the search to fail for the 
SPIR-V opnames; it - // performs poorly on short strings, taking over 300000 iterations to - // diverge on "Ceil" and "FMix" (and place them in already unoccupied - // slots)! - // - // Use FNV Hash here which seem perform much better on these short inputs - // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function - // - // If you change this, don't forget to also sync the version below in - // the printing code. - UInt64 h = salt; - for (const char c : s) h = ((h * 0x00000100000001B3) ^ c); - return h % nBuckets; - }; - - // Assign the inputs into their buckets according to the hash without salt. - // Sort the buckets according to size, so that later we can make these have - // unique destinations starting with the largest ones first as they are at - // most risk of collision. - for (const auto& s : ss) - { - initialBuckets[hash(s)].add(s); - } - initialBuckets.stableSort([](const List& a, const List& b) { return a.getCount() > b.getCount(); }); - - // These are our outputs, the salts are calculated such that for all input - // word, x, hash(x, salt[hash(x, 0)]) is unique - // - // We keep the final table as we need to detect when we've been given a - // word not in our language. - hashParams.saltTable.setCount(nBuckets); - for (auto& s : hashParams.saltTable) - { - s = 0; - } - hashParams.destTable.setCount(nBuckets); - for (auto& s : hashParams.destTable) - { - s.reduceLength(0); - } - - // This mask will, in each salt tryout, be used to prevent collisions - // within a single bucket. - List bucketDestinations = List::makeRepeated(false, nBuckets); - - for (const auto& b : initialBuckets) - { - // Break if we've reached the empty buckets - if (!b.getCount()) - { - break; - } - - // Try out all the salts until we get one which has no internal - // collisions for this bucket and also no collisions with the buckets - // we've processed so far. 
- UInt32 salt = 1; - while (true) - { - bool collision = false; - for (auto& d : bucketDestinations) - { - d = false; - } - - for (const auto& s : b) - { - const auto i = hash(s, salt); - if (hashParams.destTable[i].getLength() || bucketDestinations[i]) - { - collision = true; - break; - } - bucketDestinations[i] = true; - } - if (!collision) - { - break; - } - salt++; - - // If we fail to find a solution after some massive amount of tries - // it's almost certainly because of some property of the hash - // function and language causing an irresolvable collision. - if (salt > 10000 * nBuckets) - { - return UnavoidableHashCollision; - } - } - for (const auto& s : b) - { - hashParams.saltTable[hash(s)] = salt; - hashParams.destTable[hash(s, salt)] = s; - } - } - return Success; -} - void writeHashFile( const char* const outCppPath, const char* valueType, const char* valuePrefix, const List includes, - const HashParams& hashParams) + const HashParams& hashParams, + const List values) { StringBuilder sb; StringWriter writer(&sb, WriterFlags(0)); @@ -230,68 +97,12 @@ void writeHashFile( w.print("{\n"); w.print("\n"); - w.print("static const unsigned tableSalt[%ld] =", hashParams.saltTable.getCount()); - w.print("{\n "); - for (Index i = 0; i < hashParams.saltTable.getCount(); ++i) - { - const auto salt = hashParams.saltTable[i]; - if (i != hashParams.saltTable.getCount() - 1) - { - w.print(" %d,", salt); - if (i % 16 == 15) - { - w.print("\n "); - } - } - else - { - w.print(" %d", salt); - } - } - w.print("\n};\n"); - w.print("\n"); - - w.print("struct KV\n"); - w.print("{\n"); - w.print(" const char* name;\n"); - w.print(" %s value;\n", valueType); - w.print("};\n"); - w.print("\n"); - - w.print("static const KV words[%ld] =\n", hashParams.destTable.getCount()); - w.print("{\n"); - for (const auto& s : hashParams.destTable) - { - w.print(" {\"%s\", %s%s},\n", s.getBuffer(), valuePrefix, s.getBuffer()); - } - w.print("};\n"); - w.print("\n"); - - // Make sure to update 
the hash function in the search function above if - // you change this. - w.print("static UInt32 hash(const UnownedStringSlice& str, UInt32 salt)\n"); - w.print("{\n"); - w.print(" UInt64 h = salt;\n"); - w.print(" for(const char c : str)\n"); - w.print(" h = ((h * 0x00000100000001B3) ^ c);\n"); - w.print(" return h %% (sizeof(tableSalt)/sizeof(tableSalt[0]));\n"); - w.print("}\n"); - w.print("\n"); - - w.print("bool lookup%s(const UnownedStringSlice& str, %s& value)\n", valueType, valueType); - w.print("{\n"); - w.print(" const auto i = hash(str, tableSalt[hash(str, 0)]);\n"); - w.print(" if(str == words[i].name)\n"); - w.print(" {\n"); - w.print(" value = words[i].value;\n"); - w.print(" return true;\n"); - w.print(" }\n"); - w.print(" else\n"); - w.print(" {\n"); - w.print(" return false;\n"); - w.print(" }\n"); - w.print("}\n"); - w.print("\n"); + w.put(perfectHashToEmbeddableCpp( + hashParams, + UnownedStringSlice(valueType), + (String("lookup") + valueType).getUnownedSlice(), + values + ).getBuffer()); w.print("}\n"); @@ -356,10 +167,10 @@ int main(int argc, const char* const* argv) } HashParams hashParams; - auto r = minimalPerfectHash(opnames, hashParams); + auto r = minimalPerfectHash(opnames, hashParams); switch (r) { - case UnavoidableHashCollision: + case HashFindResult::UnavoidableHashCollision: { sink.diagnoseRaw( Severity::Error, @@ -368,20 +179,25 @@ int main(int argc, const char* const* argv) "collision for some input words\n"); return 1; } - case NonUniqueKeys: + case HashFindResult::NonUniqueKeys: { sink.diagnoseRaw(Severity::Error, "Input word list has duplicates\n"); return 1; } - case Success:; + case HashFindResult::Success:; } + List values; + values.reserve (hashParams.destTable.getCount()); + for(const auto& v : hashParams.destTable) + values.add(enumerantPrefix + v); writeHashFile( outCppPath, enumName, enumerantPrefix, { "../core/slang-common.h", "../core/slang-string.h", enumHeader }, - hashParams); + hashParams, + values); return 0; } 
diff --git a/tools/slang-spirv-embed-generator/spirv-embed-generator-main.cpp b/tools/slang-spirv-embed-generator/spirv-embed-generator-main.cpp new file mode 100644 index 000000000..30f1ef6b9 --- /dev/null +++ b/tools/slang-spirv-embed-generator/spirv-embed-generator-main.cpp @@ -0,0 +1,471 @@ +#include + +#include "source/core/slang-dictionary.h" +#include "source/core/slang-io.h" +#include "source/compiler-core/slang-diagnostic-sink.h" +#include "source/compiler-core/slang-perfect-hash.h" +#include "source/core/slang-writer.h" +#include "source/compiler-core/slang-spirv-core-grammar.h" +#include "source/compiler-core/slang-lexer.h" + +using namespace Slang; + +// +// Go from a dictionary to a C++ embedding of a perfect hash +// +template +String dictToPerfectHash( + const Dictionary& dict, + const UnownedStringSlice& type, + const UnownedStringSlice& funcName, + F valueToString) +{ + HashParams hashParams; + List names; + for(const auto& [name, val] : dict) + names.add(name); + auto r = minimalPerfectHash(names, hashParams); + SLANG_ASSERT(r == HashFindResult::Success); + List values; + values.reserve(hashParams.destTable.getCount()); + for(const auto& v : hashParams.destTable) + { + values.add(valueToString(dict.getValue(v.getUnownedSlice()))); + } + return perfectHashToEmbeddableCpp(hashParams, type, funcName, values); +} + +// +// Go from a dictionary to a C++ embedding of switch table +// +template +void dictToSwitch( + const Dictionary& dict, + const char* funName, + const char* keyType, + const char* valueType, + const char* unpackKey, + const F1 keyToString, + const F2 valueToAssignmentString, + WriterHelper& w) +{ + const auto line = [&](const auto& l){ + w.put(l); + w.put("\n"); + }; + + w.print("static bool %s(const %s& k, %s& v)\n", funName, keyType, valueType); + line("{"); + w.print(" switch(%s)\n", unpackKey); + line(" {"); + for(const auto& [k, v] : dict) + { + const auto kStr = keyToString(k); + const auto vStr = valueToAssignmentString(v); + 
w.print( + " case %s:\n" + " {\n" + " %s;\n" + " return true;\n" + " }\n", + kStr.getBuffer(), + vStr.getBuffer() + ); + } + line(" default: return false;"); + line(" }"); + line("}"); + line(""); +} + +// +// Go from a dictionary to a C++ embedding of switch table, specific to the +// two-level table of a QualifiedEnumValue +// +template +void qualifiedEnumValueNameSwitch( + const Dictionary& dict, + const char* funName, + const char* keyType, + const char* valueType, + const char* unpackKey1, + const F valueToAssignmentString, + WriterHelper& w) +{ + const auto line = [&](const auto& l){ + w.put(l); + w.put("\n"); + }; + + using K1 = Slang::SPIRVCoreGrammarInfo::OperandKind; + using K2 = SpvWord; + Dictionary> stepDict; + for(const auto& [k, v] : dict) + { + const auto& [k1, k2] = k; + stepDict[k1][k2] = v; + } + + w.print("static bool %s(const %s& k, %s& v)\n", funName, keyType, valueType); + line("{"); + line(" const auto& [k1, k2] = k;"); + w.print(" switch(%s)\n", unpackKey1); + line(" {"); + for(const auto& [k1, inner] : stepDict) + { + const auto k1Str = String(k1.index); + w.print(" case %s:\n", k1Str.getBuffer()); + + line(" switch(k2)"); + line(" {"); + for(const auto& [k2, v] : inner) + { + const auto k2Str = String(k2); + const auto vStr = valueToAssignmentString(v); + w.print(" case %s: %s; return true;\n", k2Str.getBuffer(), vStr.getBuffer()); + } + line(" default: return false;"); + line(" }"); + } + line(" default: return false;"); + line(" }"); + line("}"); + line(""); +} + +static const char* opClassToString(Slang::SPIRVCoreGrammarInfo::OpInfo::Class c) +{ + switch(c) + { +#define GO(n) case SPIRVCoreGrammarInfo::OpInfo::n: return #n; + GO(Miscellaneous) + GO(Debug) + GO(Annotation) + GO(Extension) + GO(ModeSetting) + GO(TypeDeclaration) + GO(ConstantCreation) + GO(Memory) + GO(Function) + GO(Image) + GO(Conversion) + GO(Composite) + GO(Arithmetic) + GO(Bit) + GO(Relational_and_Logical) + GO(Derivative) + GO(ControlFlow) + GO(Atomic) + 
GO(Primitive) + GO(Barrier) + GO(Group) + GO(DeviceSideEnqueue) + GO(Pipe) + GO(NonUniform) + GO(Reserved) + default: + GO(Other) +#undef GO + } +} + +// +// Write a C++ embedding of the SPIRVCoreGrammarInfo struct +// +void writeInfo( + const char* const outCppPath, + const SPIRVCoreGrammarInfo& info) +{ + StringBuilder sb; + StringWriter writer(&sb, WriterFlags(0)); + WriterHelper w(&writer); + const auto line = [&](const auto& l){ + w.put(l); + w.put("\n"); + }; + + // + // Intro + // + line("// Source embedding for SPIR-V core grammar"); + line("//"); + line("// This file was carefully generated by a machine,"); + line("// don't even think about modifying it yourself!"); + line("//"); + line(""); + line("#include \"../core/slang-smart-pointer.h\""); + line("#include \"../compiler-core/slang-spirv-core-grammar.h\""); + line("namespace Slang"); + line("{"); + line("using OperandKind = SPIRVCoreGrammarInfo::OperandKind;"); + line("using QualifiedEnumName = SPIRVCoreGrammarInfo::QualifiedEnumName;"); + line("using QualifiedEnumValue = SPIRVCoreGrammarInfo::QualifiedEnumValue;"); + + // + // Each block writes the lookup function for a member table + // Read the memberAssignments addition to see which one + // + List memberAssignments; + + + { + memberAssignments.add("info->opcodes.embedded = &lookupSpvOp;"); + w.put("static "); + w.put(dictToPerfectHash( + info.opcodes.dict, + UnownedStringSlice("SpvOp"), + UnownedStringSlice("lookupSpvOp"), + [](const auto n){ + const auto radix = 10; + return "static_cast(" + String(n, radix) + ")"; + } + ).getBuffer()); + } + + { + memberAssignments.add("info->capabilities.embedded = &lookupSpvCapability;"); + w.put("static "); + w.put(dictToPerfectHash( + info.capabilities.dict, + UnownedStringSlice("SpvCapability"), + UnownedStringSlice("lookupSpvCapability"), + [](const auto n){ + const auto radix = 10; + return "static_cast(" + String(n, radix) + ")"; + } + ).getBuffer()); + } + + { + 
memberAssignments.add("info->allEnumsWithTypePrefix.embedded = &lookupEnumWithTypePrefix;"); + w.put("static "); + w.put(dictToPerfectHash( + info.allEnumsWithTypePrefix.dict, + UnownedStringSlice("SpvWord"), + UnownedStringSlice("lookupEnumWithTypePrefix"), + [](const auto n){ + const auto radix = 10; + return "SpvWord{" + String(n, radix) + "}"; + } + ).getBuffer()); + } + + { + memberAssignments.add("info->opInfos.embedded = &getOpInfo;"); + dictToSwitch( + info.opInfos.dict, + "getOpInfo", + "SpvOp", + "SPIRVCoreGrammarInfo::OpInfo", + "k", + [&](SpvOp o){ + return "Spv" + String(info.opNames.dict.getValue(o)); + }, + [](const Slang::SPIRVCoreGrammarInfo::OpInfo& i){ + const char* classStr = opClassToString(i.class_); + String ret; + if(i.numOperandTypes) + { + ret.append("const static OperandKind operandTypes[] = {"); + String operandTypes; + for(Index o = 0; o < i.numOperandTypes; ++o) + { + if(o != 0) + ret.append(", "); + ret.append("{" + String(i.operandTypes[o].index) + "}"); + } + ret.append("};\n "); + } + ret.append( + String("v = {SPIRVCoreGrammarInfo::OpInfo::") + + classStr + ", " + + String(i.resultTypeIndex) + ", " + + String(i.resultIdIndex) + ", " + + String(i.minOperandCount) + ", " + + (i.maxOperandCount == 0xffff ? String("0xffff") : String(i.maxOperandCount)) + ", " + + String(i.numOperandTypes) + ", " + + (i.numOperandTypes ? 
"operandTypes" : "nullptr") + + "}"); + return ret; + }, + w + ); + } + + { + memberAssignments.add("info->opNames.embedded = &getOpName;"); + dictToSwitch( + info.opNames.dict, + "getOpName", + "SpvOp", + "UnownedStringSlice", + "k", + [&](SpvOp o){ + return "Spv" + String(info.opNames.dict.getValue(o)); + }, + [](const UnownedStringSlice& i){ + return "v = UnownedStringSlice{\"" + String(i) + "\"}"; + }, + w + ); + } + + { + memberAssignments.add("info->operandKinds.embedded = &lookupOperandKind;"); + w.put("static "); + w.put(dictToPerfectHash( + info.operandKinds.dict, + UnownedStringSlice("OperandKind"), + UnownedStringSlice("lookupOperandKind"), + [](const auto n){ + const auto radix = 10; + return "OperandKind{" + String(n.index, radix) + "}"; + } + ).getBuffer()); + } + + { + memberAssignments.add("info->allEnums.embedded = &lookupQualifiedEnum;"); + + // First construct a helper function which will lookup an enum name + // with a hex prefix representing the kind. This allows us to just + // reuse the existing string-based perfect hasher + Dictionary enumDict; + Index maxNameLength = 0; + for(const auto& [q, v] : info.allEnums.dict) + { + const auto i = q.kind.index; + String k; + k.appendChar(char((i >> 4) + 'a')); + k.appendChar(char((i & 0xf) + 'a')); + k.append(q.name); + enumDict.add(k, v); + maxNameLength = std::max(maxNameLength, k.getLength()); + } + w.put(dictToPerfectHash( + enumDict, + UnownedStringSlice("SpvWord"), + UnownedStringSlice("lookupEnumWithHexPrefix"), + [&](const auto n){ return "SpvWord{" + String(n) + "}"; } + ).getBuffer()); + + // Utilise this helper + line("static bool lookupQualifiedEnum(const QualifiedEnumName& k, SpvWord& v)"); + line("{"); + line(" static_assert(sizeof(k.kind.index) == 1);"); + w.print(" if(k.name.getLength() > %ld)\n", maxNameLength); + line(" return false;"); + w.print(" char name[%ld];\n", maxNameLength + 2); + line(" name[0] = char((k.kind.index >> 4) + 'a');"); + line(" name[1] = char((k.kind.index & 
0xf) + 'a');"); + line(" memcpy(name+2, k.name.begin(), k.name.getLength());"); + line(" return lookupEnumWithHexPrefix(UnownedStringSlice(name, k.name.getLength() + 2), v);"); + line("}"); + line(""); + } + + { + memberAssignments.add("info->allEnumNames.embedded = &getQualifiedEnumName;"); + qualifiedEnumValueNameSwitch( + info.allEnumNames.dict, + "getQualifiedEnumName", + "QualifiedEnumValue", + "UnownedStringSlice", + "k1.index", + [](const UnownedStringSlice& i){ + return "v = UnownedStringSlice{\"" + String(i) + "\"}"; + }, + w + ); + } + + { + memberAssignments.add("info->operandKindNames.embedded = &getOperandKindName;"); + dictToSwitch( + info.operandKindNames.dict, + "getOperandKindName", + "OperandKind", + "UnownedStringSlice", + "k.index", + [&](Slang::SPIRVCoreGrammarInfo::OperandKind o){ + return String(o.index); + }, + [](const UnownedStringSlice& i){ + return "v = UnownedStringSlice{\"" + String(i) + "\"}"; + }, + w + ); + } + + { + memberAssignments.add("info->operandKindUnderneathIds.embedded = &getOperandKindUnderneathId;"); + dictToSwitch( + info.operandKindUnderneathIds.dict, + "getOperandKindUnderneathId", + "OperandKind", + "OperandKind", + "k.index", + [](Slang::SPIRVCoreGrammarInfo::OperandKind o){ + return String(o.index); + }, + [](Slang::SPIRVCoreGrammarInfo::OperandKind i){ + return "v = OperandKind{" + String(i.index) + "}"; + }, + w + ); + } + + // + // Now write out the function which holds onto the static embedded info table + // + line("RefPtr SPIRVCoreGrammarInfo::getEmbeddedVersion()"); + line("{"); + line(" static RefPtr embedded = [](){"); + line(" RefPtr info = new SPIRVCoreGrammarInfo();"); + for(const auto& a : memberAssignments) + line((" " + a).getBuffer()); + + // + line(" return info;"); + line(" }();"); + line(" return embedded;"); + line("}"); + line("}"); + + File::writeAllTextIfChanged(outCppPath, sb.getUnownedSlice()); +} + +int main(int argc, const char* const* argv) +{ + using namespace Slang; + + if (argc != 3) 
+ { + fprintf( + stderr, + "Usage: %s spirv.core.grammar.json output.cpp\n", + argc >= 1 ? argv[0] : "slang-spirv-embed-generator"); + return 1; + } + + const char* const inPath = argv[1]; + const char* const outCppPath = argv[2]; + + RefPtr writer(new FileWriter(stderr, WriterFlag::AutoFlush)); + SourceManager sourceManager; + sourceManager.initialize(nullptr, nullptr); + DiagnosticSink sink(&sourceManager, Lexer::sourceLocationLexer); + sink.writer = writer; + + String contents; + SLANG_RETURN_ON_FAIL(File::readAllText(inPath, contents)); + PathInfo pathInfo = PathInfo::makeFromString(inPath); + SourceFile* sourceFile = sourceManager.createSourceFileWithString(pathInfo, contents); + SourceView* sourceView = sourceManager.createSourceView(sourceFile, nullptr, SourceLoc()); + + RefPtr info = SPIRVCoreGrammarInfo::loadFromJSON(*sourceView, sink); + + writeInfo(outCppPath, *info); + + return 0; +} -- cgit v1.2.3