diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2020-08-31 13:02:55 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-08-31 13:02:55 -0400 |
| commit | 69025ad82238a7402b18d9c566fac1574faef684 (patch) | |
| tree | ab01e4248071f9f597aa04f15742852a7662ed23 /source | |
| parent | baa789e0c9109bcb1e717ce4a9953709e7345e55 (diff) | |
AST Serialization in Modules (#1524)
* First pass at filter for AST serial writing.
* Serialization of AST for modules.
* Removed some commented out source.
Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ast-serialize.cpp | 251 | ||||
| -rw-r--r-- | source/slang/slang-ast-serialize.h | 50 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 7 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.cpp | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-serialize-types.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-serialize.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-ir-serialize.h | 3 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 15 |
10 files changed, 389 insertions, 57 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e88260414..bd7d3a4f1 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -366,6 +366,8 @@ class ModuleDecl : public NamespaceDeclBase // Module* module = nullptr; + SLANG_UNREFLECTED + /// Map a type to the list of extensions of that type (if any) declared in this module /// /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. @@ -479,4 +481,16 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; +// Import a Declaration that has been defined 'externally' and referenced in this module. +// Includes the managed name of the declaration. +// This declaration can be added when a module is serialized to any declaration that is +// external to the module being serialized, such that it replaces the reference to the +// declaration in another module. +class ImportExternalDecl : public DeclBase +{ + SLANG_CLASS(ImportExternalDecl) + + String mangledName; +}; + } // namespace Slang diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 806a9146b..a821729a6 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -218,6 +218,7 @@ struct ASTDumpContext { if (m_dumpFlags & ASTDumpUtil::Flag::HideSourceLoc) { + ScopeWrite(this).getBuf() << "SourceLoc(0)"; return; } @@ -268,7 +269,7 @@ struct ASTDumpContext { ScopeWrite(this).getBuf() << " { " << TokenTypeToString(token.type) << ", "; dump(token.loc); - m_writer->emit(" "); + m_writer->emit(", "); dump(token.getContent()); m_writer->emit(" }"); } @@ -443,11 +444,11 @@ struct ASTDumpContext { if (qualType.isLeftValue) { - m_writer->emit("left "); + m_writer->emit("lvalue "); } else { - m_writer->emit("right "); + m_writer->emit("rvalue "); } dump(qualType.type); } @@ -580,8 +581,10 @@ struct ASTDumpContext void dump(ASTNodeType nodeType) { - SLANG_UNUSED(nodeType) - // Don't bother to output anything - as will already have been dumped with the object name + // Get the class + auto info = ReflectClassInfo::getInfo(nodeType); + // Write the name + m_writer->emit(info->m_name); } void dumpObjectFull(NodeBase* node); diff --git a/source/slang/slang-ast-serialize.cpp b/source/slang/slang-ast-serialize.cpp index 6eaa29968..22d5c3de2 100644 --- a/source/slang/slang-ast-serialize.cpp +++ b/source/slang/slang-ast-serialize.cpp @@ -8,9 +8,12 @@ #include "slang-type-layout.h" #include "slang-ast-dump.h" +#include "slang-mangle.h" #include "slang-ast-support-types.h" +#include "slang-legalize-types.h" + #include "../core/slang-byte-encode-util.h" namespace Slang { @@ -46,6 +49,60 @@ static void _toNativeValue(ASTSerialReader* reader, const SERIAL_TYPE& src, NATI ASTSerialTypeInfo<NATIVE_TYPE>::toNative(reader, &src, &dst); } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleASTSerialFilter !!!!!!!!!!!!!!!!!!!!!!!! + + +ASTSerialIndex ModuleASTSerialFilter::writePointer(ASTSerialWriter* writer, const NodeBase* inPtr) +{ + NodeBase* ptr = const_cast<NodeBase*>(inPtr); + SLANG_ASSERT(ptr); + + if (Decl* decl = as<Decl>(ptr)) + { + ModuleDecl* moduleDecl = findModuleForDecl(decl); + SLANG_ASSERT(moduleDecl); + if (moduleDecl && moduleDecl != m_moduleDecl) + { + ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); + + // It's a reference to a declaration in another module, so create an ImportExternalDecl. + + String mangledName = getMangledName(astBuilder, decl); + + ImportExternalDecl* importDecl = astBuilder->create<ImportExternalDecl>(); + importDecl->mangledName = mangledName; + const ASTSerialIndex index = writer->writePointer(importDecl); + + // Set as the index of this + writer->setPointerIndex(ptr, index); + + return index; + } + else + { + // Okay... we can just write it out then + return writer->writePointer(ptr); + } + } + + // TODO(JS): What we really want to do here is to ignore bodies functions. + // It's not 100% clear if this is even right though - for example does type inference + // imply the body is needed to say infer a return type? + // Also not clear if statements in other scenarios (if there are others) might need to be kept. + // + // For now we just ignore all stmts + + if (Stmt* stmt = as<Stmt>(ptr)) + { + // + writer->setPointerIndex(stmt, ASTSerialIndex(0)); + return ASTSerialIndex(0); + } + + // For now for everything else just write it + return writer->writePointer(ptr); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Serial <-> Native conversion !!!!!!!!!!!!!!!!!!!!!!!! @@ -1030,28 +1087,20 @@ ASTSerialClasses::ASTSerialClasses(): // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!! -ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes) : +ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes, ASTSerialFilter* filter) : m_arena(2048), - m_classes(classes) + m_classes(classes), + m_filter(filter) { // 0 is always the null pointer m_entries.add(nullptr); m_ptrMap.Add(nullptr, 0); } -ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) +ASTSerialIndex ASTSerialWriter::writePointer(const NodeBase* node) { - // Null is always 0 - if (node == nullptr) - { - return ASTSerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.TryGetValue(node); - if (indexPtr) - { - return ASTSerialIndex(*indexPtr); - } + // This pointer cannot be in the map + SLANG_ASSERT(m_ptrMap.TryGetValue(node) == nullptr); const ASTSerialClass* serialClass = m_classes->getSerialClass(node->astNodeType); @@ -1072,6 +1121,7 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) for (Index i = 0; i < serialClass->fieldsCount; ++i) { auto field = serialClass->fields[i]; + // Work out the offsets auto srcField = ((const uint8_t*)node) + field.nativeOffset; auto dstField = serialPayload + field.serialOffset; @@ -1088,6 +1138,35 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) return index; } +void ASTSerialWriter::setPointerIndex(const NodeBase* ptr, ASTSerialIndex index) +{ + m_ptrMap.Add(ptr, Index(index)); +} + +ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) +{ + // Null is always 0 + if (node == nullptr) + { + return ASTSerialIndex(0); + } + // Look up in the map + Index* indexPtr = m_ptrMap.TryGetValue(node); + if (indexPtr) + { + return ASTSerialIndex(*indexPtr); + } + + if (m_filter) + { + return m_filter->writePointer(this, node); + } + else + { + return writePointer(node); + } +} + ASTSerialIndex ASTSerialWriter::addPointer(const RefObject* obj) { // Null is always 0 @@ -1246,6 +1325,8 @@ ASTSerialIndex ASTSerialWriter::_addArray(size_t elementSize, size_t alignment, return ASTSerialIndex(m_entries.getCount() - 1); } +static const uint8_t s_fixBuffer[ASTSerialInfo::MAX_ALIGNMENT]{ 0, }; + SlangResult ASTSerialWriter::write(Stream* stream) { const Int entriesCount = m_entries.getCount(); @@ -1263,8 +1344,6 @@ SlangResult ASTSerialWriter::write(Stream* stream) // knowing that removeLast cannot release memory, means the sentinal must be at the last position. entries[entriesCount] = &sentinal; - - static const uint8_t fixBuffer[ASTSerialInfo::MAX_ALIGNMENT] { 0, }; { size_t offset = 0; @@ -1303,7 +1382,7 @@ SlangResult ASTSerialWriter::write(Stream* stream) // If we needed to fix so that subsequent alignment is right, write out extra bytes here if (alignmentFixSize) { - stream->write(fixBuffer, alignmentFixSize); + stream->write(s_fixBuffer, alignmentFixSize); } } catch (const IOException&) @@ -1320,6 +1399,80 @@ SlangResult ASTSerialWriter::write(Stream* stream) return SLANG_OK; } +SlangResult ASTSerialWriter::writeIntoContainer(RiffContainer* container) +{ + typedef RiffContainer::Chunk Chunk; + typedef RiffContainer::ScopeChunk ScopeChunk; + + // This is the container for the AST Data + ScopeChunk scopeModule(container, Chunk::Kind::List, ASTSerialBinary::kSlangASTModuleFourCC); + { + ScopeChunk scopeData(container, Chunk::Kind::Data, ASTSerialBinary::kSlangASTModuleDataFourCC); + + { + // Sentinal so we don't need special handling for end of list + ASTSerialInfo::Entry sentinal; + sentinal.type = ASTSerialInfo::Type::String; + sentinal.info = ASTSerialInfo::EntryInfo::Alignment1; + + size_t offset = 0; + const Int entriesCount = m_entries.getCount(); + + { + m_entries.add(&sentinal); + m_entries.removeLast(); + // Note strictly required in our impl of List. But by writing this and + // knowing that removeLast cannot release memory, means the sentinal must be at the last position. + m_entries.getBuffer()[entriesCount] = &sentinal; + } + + ASTSerialInfo::Entry*const* entries = m_entries.getBuffer(); + + ASTSerialInfo::Entry* entry = entries[1]; + // We start on 1, because 0 is nullptr and not used for anything + for (Index i = 1; i < entriesCount; ++i) + { + ASTSerialInfo::Entry* next = entries[i + 1]; + + // Before writing we need to store the next alignment + + const size_t nextAlignment = ASTSerialInfo::getAlignment(next->info); + const size_t alignment = ASTSerialInfo::getAlignment(entry->info); + + entry->info = ASTSerialInfo::combineWithNext(entry->info, next->info); + + // Check we are aligned correctly + SLANG_ASSERT((offset & (alignment - 1)) == 0); + + // When we write, we need to make sure it take into account the next alignment + const size_t entrySize = entry->calcSize(m_classes); + + // Work out the fix for next alignment + size_t nextOffset = offset + entrySize; + nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); + + size_t alignmentFixSize = nextOffset - (offset + entrySize); + + // The fix must be less than max alignment. We require it to be less because we aligned each Entry to + // MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra bytes we can write + SLANG_ASSERT(alignmentFixSize < ASTSerialInfo::MAX_ALIGNMENT); + + container->write(entry, entrySize); + if (alignmentFixSize) + { + container->write(s_fixBuffer, alignmentFixSize); + } + + // Onto next + offset = nextOffset; + entry = next; + } + } + } + + return SLANG_OK; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialInfo::Entry !!!!!!!!!!!!!!!!!!!!!!!! size_t ASTSerialInfo::Entry::calcSize(ASTSerialClasses* serialClasses) const @@ -1713,6 +1866,62 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil } +/* static */Result ASTSerialReader::readContainerModules(RiffContainer* container, Linkage* linkage, List<RefPtr<Module>>& outModules) +{ + List<RiffContainer::ListChunk*> moduleChunks; + // First try to find a list + { + RiffContainer::ListChunk* listChunk = container->getRoot()->findListRec(SerialBinary::kSlangModuleListFourCc); + if (listChunk) + { + listChunk->findContained(ASTSerialBinary::kSlangASTModuleFourCC, moduleChunks); + } + else + { + // Maybe its just a single module + RiffContainer::ListChunk* moduleChunk = container->getRoot()->findListRec(ASTSerialBinary::kSlangASTModuleFourCC); + if (!moduleChunk) + { + // Couldn't find any modules + return SLANG_FAIL; + } + moduleChunks.add(moduleChunk); + } + } + + RefPtr<ASTSerialClasses> serialClasses(new ASTSerialClasses); + + // Okay, deserialize the each of the module chunks + for (RiffContainer::ListChunk* listChunk : moduleChunks) + { + // Look for the module data + auto data = listChunk->findContainedData(ASTSerialBinary::kSlangASTModuleDataFourCC); + + if (!data) + { + return SLANG_FAIL; + } + + ASTSerialReader reader(serialClasses); + + RefPtr<Module> module(new Module(linkage)); + SLANG_RETURN_ON_FAIL(reader.load((uint8_t*)data->getPayload(), data->getSize(), module->getASTBuilder(), linkage->getNamePool())); + + ModuleDecl* moduleDecl = reader.getPointer(ASTSerialIndex(1)).dynamicCast<ModuleDecl>(); + if (!moduleDecl) + { + return SLANG_FAIL; + } + + // Set on the module + module->setModuleDecl(moduleDecl); + + outModules.add(module); + } + + return SLANG_OK; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerializeUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!! /* static */SlangResult ASTSerialTestUtil::selfTest() @@ -1762,7 +1971,12 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil { OwnedMemoryStream stream(FileAccess::ReadWrite); - ASTSerialWriter writer(classes); + ModuleDecl* moduleDecl = as<ModuleDecl>(node); + ModuleASTSerialFilter filterStorage(moduleDecl); + + ASTSerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + + ASTSerialWriter writer(classes, filter); // Lets serialize it all writer.addPointer(node); @@ -1831,6 +2045,7 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil File::writeAllText("ast-read.ast-dump", readDump); File::writeAllText("ast-orig.ast-dump", origDump); + if (readDump != origDump) { return SLANG_FAIL; diff --git a/source/slang/slang-ast-serialize.h b/source/slang/slang-ast-serialize.h index c2086f434..a9fa9f605 100644 --- a/source/slang/slang-ast-serialize.h +++ b/source/slang/slang-ast-serialize.h @@ -7,15 +7,18 @@ #include "slang-ast-support-types.h" #include "slang-ast-all.h" +#include "../core/slang-riff.h" + #include "slang-ast-builder.h" #include "../core/slang-byte-encode-util.h" #include "../core/slang-stream.h" - namespace Slang { +class Linkage; + /* AST Serialization Overview ========================== @@ -188,6 +191,16 @@ An extra wrinkle is that we allow accessing of a serialized String as a Name or and a Name remains in scope as long as it's NamePool does which is passed in. */ +/* Holds RIFF FourCC codes for AST types */ +struct ASTSerialBinary +{ + static const FourCC kRiffFourCC = RiffFourCC::kRiff; + + /// AST module LIST container + static const FourCC kSlangASTModuleFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'l'); + /// AST module data + static const FourCC kSlangASTModuleDataFourCC = SLANG_FOUR_CC('S', 'A', 'm', 'd'); +}; class ASTSerialClasses; @@ -364,6 +377,9 @@ public: /// NOTE! data must stay ins scope when reading takes place SlangResult load(const uint8_t* data, size_t dataCount, ASTBuilder* builder, NamePool* namePool); + /// Read the modules from the container + static Result readContainerModules(RiffContainer* container, Linkage* linkage, List<RefPtr<Module>>& outModules); + ASTSerialReader(ASTSerialClasses* classes): m_classes(classes) { @@ -415,6 +431,26 @@ void ASTSerialReader::getArray(ASTSerialIndex index, List<T>& out) class ASTSerialClasses; +class ASTSerialWriter; + +class ASTSerialFilter +{ +public: + virtual ASTSerialIndex writePointer(ASTSerialWriter* writer, const NodeBase* ptr) = 0; +}; + +class ModuleASTSerialFilter : public ASTSerialFilter +{ +public: + virtual ASTSerialIndex writePointer(ASTSerialWriter* writer, const NodeBase* ptr) SLANG_OVERRIDE; + + ModuleASTSerialFilter(ModuleDecl* moduleDecl): + m_moduleDecl(moduleDecl) + { + } + + ModuleDecl* m_moduleDecl; +}; /* This is a class used tby toSerial implementations to turn native type into the serial type */ class ASTSerialWriter : public RefObject @@ -423,6 +459,9 @@ public: ASTSerialIndex addPointer(const NodeBase* ptr); ASTSerialIndex addPointer(const RefObject* ptr); + /// Write the pointer + ASTSerialIndex writePointer(const NodeBase* ptr); + template <typename T> ASTSerialIndex addArray(const T* in, Index count); @@ -431,13 +470,19 @@ public: ASTSerialIndex addName(const Name* name); ASTSerialSourceLoc addSourceLoc(SourceLoc sourceLoc); + /// Set a the index associated with an index. NOTE! That there cannot be a pre-existing setting. + void setPointerIndex(const NodeBase* ptr, ASTSerialIndex index); + /// Get the entries table holding how each index maps to an entry const List<ASTSerialInfo::Entry*>& getEntries() const { return m_entries; } /// Write to a stream SlangResult write(Stream* stream); - ASTSerialWriter(ASTSerialClasses* classes); + /// Write the state into the container + SlangResult writeIntoContainer(RiffContainer* container); + + ASTSerialWriter(ASTSerialClasses* classes, ASTSerialFilter* filter); protected: @@ -461,6 +506,7 @@ protected: List<ASTSerialInfo::Entry*> m_entries; ///< The entries MemoryArena m_arena; ///< Holds the payloads ASTSerialClasses* m_classes; + ASTSerialFilter* m_filter; ///< Filter to control what is serialized }; // --------------------------------------------------------------------------- diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 79d9e4bd7..3ccf6fe06 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -24,6 +24,7 @@ namespace Slang {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} void visitDecl(Decl* decl) { @@ -41,6 +42,8 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} + void checkVarDeclCommon(VarDeclBase* varDecl); @@ -109,6 +112,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} #define CASE(TYPE) void visit##TYPE(TYPE* decl) { checkForRedeclaration(decl); } @@ -130,6 +134,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); @@ -164,6 +169,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} void checkVarDeclCommon(VarDeclBase* varDecl); @@ -1081,6 +1087,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} + void visitImportExternalDecl(ImportExternalDecl*) {} // Any user-defined type may have declared interface conformances, // which we should check. diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index d0b15c683..53e529cc7 100755 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -23,6 +23,8 @@ #include "slang-glsl-extension-tracker.h" #include "slang-emit-cuda.h" +#include "slang-ast-serialize.h" + #include "slang-ir-serialize.h" // Enable calling through to `fxc` or `dxc` to @@ -2331,7 +2333,7 @@ SlangResult dissassembleDXILUsingDXC( { // Module list - RiffContainer::ScopeChunk listScope(&container, RiffContainer::Chunk::Kind::List, IRSerialBinary::kSlangModuleListFourCc); + RiffContainer::ScopeChunk listScope(&container, RiffContainer::Chunk::Kind::List, SerialBinary::kSlangModuleListFourCc); auto linkage = getLinkage(); auto sink = getSink(); @@ -2344,9 +2346,11 @@ SlangResult dissassembleDXILUsingDXC( optionFlags |= IRSerialWriter::OptionFlag::DebugInfo; } + RefPtr<ASTSerialClasses> astClasses = new ASTSerialClasses; + SourceManager* sourceManager = frontEndReq->getSourceManager(); - for (auto translationUnit : frontEndReq->translationUnits) + for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) { auto module = translationUnit->module; auto irModule = module->getIRModule(); @@ -2354,10 +2358,27 @@ SlangResult dissassembleDXILUsingDXC( // Okay, we need to serialize this module to our container file. // We currently don't serialize it's name..., but support for that could be added. - IRSerialData serialData; - IRSerialWriter writer; - SLANG_RETURN_ON_FAIL(writer.write(irModule, sourceManager, optionFlags, &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, compressionType, &container)); + // Write the IR information + { + IRSerialData serialData; + IRSerialWriter writer; + SLANG_RETURN_ON_FAIL(writer.write(irModule, sourceManager, optionFlags, &serialData)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeContainer(serialData, compressionType, &container)); + } + + // Write the AST information + { + ModuleDecl* moduleDecl = translationUnit->getModuleDecl(); + + ModuleASTSerialFilter filter(moduleDecl); + ASTSerialWriter writer(astClasses, &filter); + + // Add the module and everything that isn't filtered out in the filter. + writer.addPointer(moduleDecl); + + // We can now serialize it into the riff container. + SLANG_RETURN_ON_FAIL(writer.writeIntoContainer(&container)); + } } auto program = getSpecializedGlobalAndEntryPointsComponentType(); @@ -2383,7 +2404,7 @@ SlangResult dissassembleDXILUsingDXC( auto entryPoint = program->getEntryPoint(ii); auto entryPointMangledName = program->getEntryPointMangledName(ii); - RiffContainer::ScopeChunk entryPointScope(&container, RiffContainer::Chunk::Kind::Data, IRSerialBinary::kEntryPointFourCc); + RiffContainer::ScopeChunk entryPointScope(&container, RiffContainer::Chunk::Kind::Data, SerialBinary::kEntryPointFourCc); auto writeString = [&](String const& str) { @@ -2529,7 +2550,8 @@ SlangResult dissassembleDXILUsingDXC( compileRequest, targetReq); } - else { + else + { for (Index ee = 0; ee < entryPointCount; ++ee) { writeEntryPointResult( diff --git a/source/slang/slang-ir-serialize-types.h b/source/slang/slang-ir-serialize-types.h index b6d4572bd..3d006b86e 100644 --- a/source/slang/slang-ir-serialize-types.h +++ b/source/slang/slang-ir-serialize-types.h @@ -346,16 +346,21 @@ SLANG_FORCE_INLINE int IRSerialData::getOperands(const Inst& inst, const InstInd } } -// Replace first char with 's' -#define SLANG_MAKE_COMPRESSED_FOUR_CC(fourCc) SLANG_FOUR_CC_REPLACE_FIRST_CHAR(fourCc, 's') - -struct IRSerialBinary +// For types/FourCC that work for serializing in general (not just IR). Really this should be placed in some other header +struct SerialBinary { - static const FourCC kRiffFourCc = RiffFourCC::kRiff; static const FourCC kSlangModuleListFourCc = SLANG_FOUR_CC('S', 'L', 'm', 'l'); + static const FourCC kEntryPointFourCc = SLANG_FOUR_CC('E', 'P', 'n', 't'); +}; + +// Replace first char with 's' +#define SLANG_MAKE_COMPRESSED_FOUR_CC(fourCc) SLANG_FOUR_CC_REPLACE_FIRST_CHAR(fourCc, 's') + +struct IRSerialBinary +{ static const FourCC kSlangModuleFourCc = SLANG_FOUR_CC('S', 'L', 'm', 'd'); ///< Holds all the slang specific chunks static const FourCC kSlangModuleHeaderFourCc = SLANG_FOUR_CC('S', 'L', 'h', 'd'); @@ -371,7 +376,6 @@ struct IRSerialBinary static const FourCC kCompressedChildRunFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kChildRunFourCc); static const FourCC kCompressedExternalOperandsFourCc = SLANG_MAKE_COMPRESSED_FOUR_CC(kExternalOperandsFourCc); - static const FourCC kStringFourCc = SLANG_FOUR_CC('S', 'L', 's', 't'); static const FourCC kUInt32SourceLocFourCc = SLANG_FOUR_CC('S', 'r', 's', '4'); @@ -382,8 +386,6 @@ struct IRSerialBinary static const FourCC kDebugSourceInfoFourCc = SLANG_FOUR_CC('S', 'd', 's', 'o'); static const FourCC kDebugSourceLocRunFourCc = SLANG_FOUR_CC('S', 'd', 's', 'r'); - static const FourCC kEntryPointFourCc = SLANG_FOUR_CC('E', 'P', 'n', 't'); - typedef IRSerialCompressionType CompressionType; struct ModuleHeader diff --git a/source/slang/slang-ir-serialize.cpp b/source/slang/slang-ir-serialize.cpp index 64316d457..1aa6b1d85 100644 --- a/source/slang/slang-ir-serialize.cpp +++ b/source/slang/slang-ir-serialize.cpp @@ -989,30 +989,22 @@ static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceVi return int(sourceView->getRange().begin.getRaw()) - int(info.m_startSourceLoc); } -// TODO: The following function isn't really part of the IR serialization system, but rather -// a layered "container" format, and as such probably belongs in a higher-level system that -// simply calls into the `IRSerialReader` rather than being part of it... -// -/* static */Result IRSerialReader::readStreamModules(Stream* stream, Session* session, SourceManager* sourceManager, List<RefPtr<IRModule>>& outModules, List<FrontEndCompileRequest::ExtraEntryPointInfo>& outEntryPoints) +/* static */Result IRSerialReader::readContainerModules(RiffContainer* container, Session* session, SourceManager* sourceManager, List<RefPtr<IRModule>>& outModules, List<FrontEndCompileRequest::ExtraEntryPointInfo>& outEntryPoints) { - // Load up the module - RiffContainer container; - SLANG_RETURN_ON_FAIL(RiffUtil::read(stream, container)); - List<RiffContainer::ListChunk*> moduleChunks; List<RiffContainer::DataChunk*> entryPointChunks; // First try to find a list { - RiffContainer::ListChunk* listChunk = container.getRoot()->findListRec(IRSerialBinary::kSlangModuleListFourCc); + RiffContainer::ListChunk* listChunk = container->getRoot()->findListRec(SerialBinary::kSlangModuleListFourCc); if (listChunk) { listChunk->findContained(IRSerialBinary::kSlangModuleFourCc, moduleChunks); - listChunk->findContained(IRSerialBinary::kEntryPointFourCc, entryPointChunks); + listChunk->findContained(SerialBinary::kEntryPointFourCc, entryPointChunks); } else { // Maybe its just a single module - RiffContainer::ListChunk* moduleChunk = container.getRoot()->findListRec(IRSerialBinary::kSlangModuleFourCc); + RiffContainer::ListChunk* moduleChunk = container->getRoot()->findListRec(IRSerialBinary::kSlangModuleFourCc); if (!moduleChunk) { // Couldn't find any modules @@ -1037,7 +1029,7 @@ static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceVi outModules.add(irModule); } - for( auto entryPointChunk : entryPointChunks ) + for (auto entryPointChunk : entryPointChunks) { auto reader = entryPointChunk->asReadHelper(); @@ -1046,8 +1038,8 @@ static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceVi uint32_t length = 0; reader.read(length); - char* begin = (char*) reader.getData(); - reader.skip(length+1); + char* begin = (char*)reader.getData(); + reader.skip(length + 1); return UnownedStringSlice(begin, begin + length); }; @@ -1064,6 +1056,21 @@ static int _calcFixSourceLoc(const IRSerialData::DebugSourceInfo& info, SourceVi return SLANG_OK; } + +// TODO: The following function isn't really part of the IR serialization system, but rather +// a layered "container" format, and as such probably belongs in a higher-level system that +// simply calls into the `IRSerialReader` rather than being part of it... +// +/* static */Result IRSerialReader::readStreamModules(Stream* stream, Session* session, SourceManager* sourceManager, List<RefPtr<IRModule>>& outModules, List<FrontEndCompileRequest::ExtraEntryPointInfo>& outEntryPoints) +{ + // Load up the module + RiffContainer container; + SLANG_RETURN_ON_FAIL(RiffUtil::read(stream, container)); + + SLANG_RETURN_ON_FAIL(readContainerModules(&container, session, sourceManager, outModules, outEntryPoints)); + return SLANG_OK; +} + /* static */Result IRSerialReader::read(const IRSerialData& data, Session* session, SourceManager* sourceManager, RefPtr<IRModule>& moduleOut) { typedef Ser::Inst::PayloadType PayloadType; diff --git a/source/slang/slang-ir-serialize.h b/source/slang/slang-ir-serialize.h index cec672b24..ee99c08dc 100644 --- a/source/slang/slang-ir-serialize.h +++ b/source/slang/slang-ir-serialize.h @@ -125,6 +125,9 @@ struct IRSerialReader /// Read potentially multiple modules from a stream static Result readStreamModules(Stream* stream, Session* session, SourceManager* manager, List<RefPtr<IRModule>>& outModules, List<FrontEndCompileRequest::ExtraEntryPointInfo>& outEntryPoints); + /// Read potentially multiple modules from a stream + static Result readContainerModules(RiffContainer* container, Session* session, SourceManager* manager, List<RefPtr<IRModule>>& outModules, List<FrontEndCompileRequest::ExtraEntryPointInfo>& outEntryPoints); + /// Read a stream to fill in dataOut IRSerialData static Result readContainer(RiffContainer::ListChunk* module, IRSerialData* outData); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 074c71f21..c78ca0a4d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2993,10 +2993,23 @@ namespace Slang { SlangResult _addLibraryReference(EndToEndCompileRequest* req, Stream* stream) { + // Load up the module + RiffContainer container; + SLANG_RETURN_ON_FAIL(RiffUtil::read(stream, container)); + + List<RefPtr<Module>> modules; + + if (SLANG_FAILED(ASTSerialReader::readContainerModules(&container, req->getLinkage(), modules))) + { + req->getSink()->diagnose(SourceLoc(), Diagnostics::unableToAddReferenceToModuleContainer); + return SLANG_FAIL; + } + // Read all of the contained modules List<RefPtr<IRModule>> irModules; List<FrontEndCompileRequest::ExtraEntryPointInfo> entryPointMangledNames; - if (SLANG_FAILED(IRSerialReader::readStreamModules(stream, req->getSession(), req->getFrontEndReq()->getSourceManager(), irModules, entryPointMangledNames))) + + if (SLANG_FAILED(IRSerialReader::readContainerModules(&container, req->getSession(), req->getFrontEndReq()->getSourceManager(), irModules, entryPointMangledNames))) { req->getSink()->diagnose(SourceLoc(), Diagnostics::unableToAddReferenceToModuleContainer); return SLANG_FAIL; |
