From 494e09af2cebafa34db49dc1f60afd43aebed619 Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Thu, 29 Oct 2020 11:45:56 -0400 Subject: Handling imported/exporting symbols from serialized modules (#1589) * #include an absolute path didn't work - because paths were taken to always be relative. * Fix handling of access modifiers inside type definition. * Fix access problem for AST node. Make dumping produce a single function with switch, to potentially make available without Dump specific access. * WIP on serialization design doc. * Remove project references to previously generated files. * More docs on serialization design. * Improve serialization documentation. Remove unused function from IRSerialReader. * Small fixes around naming. Remove long comment from slang-serialize.h - as covered in serialization.md * Remove long comment in slang-serialize.h as covered in serialization.md * More information about doing replacements on read for AST and problems surrounding. * Typo fix. * Spelling fixes. * Value serialize. * Value types with inheritence. * Use value reflection serial conversion for more AST types * Use automatic serialization on more of AST. * Get the types via decltype, simplifies what the extractor has to do. * Update the serialization.md for the value serialization. * Small doc improvements. * Update project. * Remove ImportExternalDecl type Added addImportSymbol and ImportSymbol type Fixed bug in container which meant it wouldn't read back AST module * Because of change of how imports and handled, store objects as SerialPointers. * First pass symbol lookup from mangled names. * Cache current module looked up from mangled name. * Fix SourceLoc bug. Improve comments. * Added diagnostic on mangled symbol not being found * Fix typo. Co-authored-by: Tim Foley --- source/slang/slang-ast-decl.h | 12 ---- source/slang/slang-check-decl.cpp | 11 +-- source/slang/slang-compiler.h | 15 +++- source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-mangled-lexer.cpp | 55 ++++++++++----- source/slang/slang-mangled-lexer.h | 32 +++++---- source/slang/slang-serialize-ast.cpp | 6 +- source/slang/slang-serialize-container.cpp | 98 +++++++++++++++++++++++++- source/slang/slang-serialize-container.h | 2 + source/slang/slang-serialize-factory.cpp | 13 +--- source/slang/slang-serialize.cpp | 104 ++++++++++++++++++---------- source/slang/slang-serialize.h | 47 +++++++++++-- source/slang/slang.cpp | 106 +++++++++++++++++++++++++++-- 13 files changed, 390 insertions(+), 113 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 4210318dc..2bf4c488b 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -481,16 +481,4 @@ class AttributeDecl : public ContainerDecl SyntaxClass syntaxClass; }; -// Import a Declaration that has been defined 'externally' and referenced in this module. -// Includes the managed name of the declaration. -// This declaration can be added when a module is serialized to any declaration that is -// external to the module being serialized, such that it replaces the reference to the -// declaration in another module. -class ImportExternalDecl : public DeclBase -{ - SLANG_AST_CLASS(ImportExternalDecl) - - String mangledName; -}; - } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0735f4620..d771e689c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -24,8 +24,7 @@ namespace Slang {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} - + void visitDecl(Decl* decl) { checkModifiers(decl); @@ -42,9 +41,7 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} - - + void checkVarDeclCommon(VarDeclBase* varDecl); void visitVarDecl(VarDecl* varDecl) @@ -113,7 +110,6 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} #define CASE(TYPE) void visit##TYPE(TYPE* decl) { checkForRedeclaration(decl); } @@ -135,7 +131,6 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); @@ -170,7 +165,6 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} void checkVarDeclCommon(VarDeclBase* varDecl); @@ -1186,7 +1180,6 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - void visitImportExternalDecl(ImportExternalDecl*) {} // Any user-defined type may have declared interface conformances, // which we should check. diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 79ef9df27..56885ab46 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -973,6 +973,10 @@ namespace Slang List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); } List const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); } + /// Given a mangled name finds the exported NodeBase associated with this module. + /// If not found returns nullptr. + NodeBase* findExportFromMangledName(const UnownedStringSlice& slice); + /// Get the ASTBuilder ASTBuilder* getASTBuilder() { return &m_astBuilder; } @@ -1004,6 +1008,7 @@ namespace Slang List> const& getEntryPoints() { return m_entryPoints; } void _addEntryPoint(EntryPoint* entryPoint); + void _processFindDeclsExportSymbolsRec(Decl* decl); protected: void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; @@ -1054,6 +1059,11 @@ namespace Slang // The builder that owns all of the AST nodes from parsing the source of // this module. ASTBuilder m_astBuilder; + + // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices, + // and m_mangledExportSymbols holds the NodeBase* values for each index. + StringSlicePool m_mangledExportPool; + List m_mangledExportSymbols; }; typedef Module LoadedModule; @@ -1234,7 +1244,7 @@ namespace Slang SlangMatrixLayoutMode mode); /// Create an initially-empty linkage - Linkage(Session* session, ASTBuilder* astBuilder); + Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); /// Dtor ~Linkage(); @@ -2137,6 +2147,9 @@ namespace Slang /// Get the prelude associated with the language const String& getPreludeForLanguage(SourceLanguage language) { return m_languagePreludes[int(language)]; } + /// Get the built in linkage -> handy to get the stdlibs from + Linkage* getBuiltinLinkage() const { return m_builtinLinkage; } + void init(); void addBuiltinSource( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d7f561537..9b4d55c14 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -397,6 +397,8 @@ DIAGNOSTIC(39999, Warning, integerLiteralTruncated, "integer literal '$0' too la DIAGNOSTIC(39999, Warning, floatLiteralUnrepresentable, "$0 literal '$1' unrepresentable, converted to '$2'") DIAGNOSTIC(39999, Warning, floatLiteralTooSmall, "'$1' is smaller than the smallest representable value for type $0, converted to '$2'") +DIAGNOSTIC(39999, Error, unableToFindSymbolInModule, "unable to find the mangled symbol '$0' in module '$1'") + // 38xxx DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") diff --git a/source/slang/slang-mangled-lexer.cpp b/source/slang/slang-mangled-lexer.cpp index f1f5ec903..237f9f2a5 100644 --- a/source/slang/slang-mangled-lexer.cpp +++ b/source/slang/slang-mangled-lexer.cpp @@ -7,13 +7,13 @@ namespace Slang { UInt MangledLexer::readCount() { - int c = _peek(); - if (!_isDigit((char)c)) + int c = peekChar(); + if (!CharUtil::isDigit((char)c)) { SLANG_UNEXPECTED("bad name mangling"); UNREACHABLE_RETURN(0); } - _next(); + nextChar(); if (c == '0') return 0; @@ -22,25 +22,25 @@ UInt MangledLexer::readCount() for (;;) { count = count * 10 + c - '0'; - c = _peek(); - if (!_isDigit((char)c)) + c = peekChar(); + if (!CharUtil::isDigit((char)c)) return count; - _next(); + nextChar(); } } void MangledLexer::readGenericParam() { - switch (_peek()) + switch (peekChar()) { case 'T': case 'C': - _next(); + nextChar(); break; case 'v': - _next(); + nextChar(); readType(); break; @@ -62,7 +62,7 @@ void MangledLexer::readGenericParams() void MangledLexer::readType() { - int c = _peek(); + int c = peekChar(); switch (c) { case 'V': @@ -73,11 +73,11 @@ void MangledLexer::readType() case 'h': case 'f': case 'd': - _next(); + nextChar(); break; case 'v': - _next(); + nextChar(); readSimpleIntVal(); readType(); break; @@ -90,15 +90,15 @@ void MangledLexer::readType() void MangledLexer::readVal() { - switch (_peek()) + switch (peekChar()) { case 'k': - _next(); + nextChar(); readCount(); break; case 'K': - _next(); + nextChar(); readRawStringSegment(); break; @@ -124,7 +124,7 @@ UnownedStringSlice MangledLexer::readSimpleName() UnownedStringSlice result; for (;;) { - int c = _peek(); + int c = peekChar(); if (c == 'g') { @@ -142,7 +142,7 @@ UnownedStringSlice MangledLexer::readSimpleName() continue; } - if (!_isDigit((char)c)) + if (!CharUtil::isDigit((char)c)) return result; // Read the length part @@ -181,4 +181,25 @@ UInt MangledLexer::readParamCount() return count; } +/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MangledNameParser !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +/* static */SlangResult MangledNameParser::parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName) +{ + MangledLexer lexer(in); + + if (lexer.peekChar() == 'T') + { + lexer.nextChar(); + } + + UnownedStringSlice name = lexer.readRawStringSegment(); + if (name.getLength() == 0) + { + return SLANG_FAIL; + } + + outModuleName = name; + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-mangled-lexer.h b/source/slang/slang-mangled-lexer.h index 6c8060cfd..7d096e45e 100644 --- a/source/slang/slang-mangled-lexer.h +++ b/source/slang/slang-mangled-lexer.h @@ -3,6 +3,7 @@ #define SLANG_MANGLED_LEXER_H_INCLUDED #include "../core/slang-basic.h" +#include "../core/slang-char-util.h" #include "slang-compiler.h" @@ -41,6 +42,12 @@ public: UInt readParamCount(); + /// Returns the character at the current position + char peekChar() { return *m_cursor; } + // Returns the current character and moves to next character. + char nextChar() { return *m_cursor++; } + + /// Ctor SLANG_FORCE_INLINE MangledLexer(const UnownedStringSlice& slice); @@ -50,13 +57,6 @@ private: // to strip off the main prefix void _start() { _expect("_S"); } - static bool _isDigit(char c) { return (c >= '0') && (c <= '9'); } - - /// Returns the character at the current position - char _peek() { return *m_cursor; } - // Returns the current character and moves to next character. - char _next() { return *m_cursor++; } - SLANG_INLINE void _expect(char c); void _expect(char const* str) @@ -82,10 +82,10 @@ SLANG_FORCE_INLINE MangledLexer::MangledLexer(const UnownedStringSlice& slice) // --------------------------------------------------------------------------- SLANG_INLINE void MangledLexer::readSimpleIntVal() { - int c = _peek(); - if (_isDigit((char)c)) + int c = peekChar(); + if (CharUtil::isDigit((char)c)) { - _next(); + nextChar(); } else { @@ -110,9 +110,9 @@ SLANG_INLINE void MangledLexer::readExtensionSpec() // --------------------------------------------------------------------------- SLANG_INLINE void MangledLexer::_expect(char c) { - if (_peek() == c) + if (peekChar() == c) { - _next(); + nextChar(); } else { @@ -121,5 +121,13 @@ SLANG_INLINE void MangledLexer::_expect(char c) } } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MangledNameParser !!!!!!!!!!!!!!!!!!!!!!!!!! + +struct MangledNameParser +{ + /// Tries to extract the module name from this mangled name. + static SlangResult parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName); +}; + } #endif diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 0e8acc3b3..bbd8237d2 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -84,8 +84,6 @@ struct ASTFieldAccess NamePool namePool; namePool.setRootNamePool(rootNamePool); - SerialReader reader(classes, nullptr); - ASTBuilder builder(sharedASTBuilder, "Serialize Check"); DefaultSerialObjectFactory objectFactory(&builder); @@ -96,7 +94,7 @@ struct ASTFieldAccess const List& writtenEntries = writer.getEntries(); List readEntries; - SlangResult res = reader.loadEntries(contents.getBuffer(), contents.getCount(), readEntries); + SlangResult res = SerialReader::loadEntries(contents.getBuffer(), contents.getCount(), classes, readEntries); SLANG_UNUSED(res); SLANG_ASSERT(writtenEntries.getCount() == readEntries.getCount()); @@ -117,7 +115,9 @@ struct ASTFieldAccess } + SerialReader reader(classes, nullptr); { + SlangResult res = reader.load(contents.getBuffer(), contents.getCount(), &namePool); SLANG_UNUSED(res); } diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index 7dd2c15a1..8c4edb9a0 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -12,6 +12,8 @@ #include "slang-serialize-source-loc.h" #include "slang-serialize-factory.h" +#include "slang-mangled-lexer.h" + namespace Slang { /* static */SlangResult SerialContainerUtil::requestToData(EndToEndCompileRequest* request, const WriteOptions& options, SerialContainerData& out) @@ -142,6 +144,9 @@ namespace Slang { { if (ModuleDecl* moduleDecl = as(module.astRootNode)) { + // Put in AST module + RiffContainer::ScopeChunk scopeASTModule(container, RiffContainer::Chunk::Kind::List, ASTSerialBinary::kSlangASTModuleFourCC); + if (!serialClasses) { SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses)); @@ -155,8 +160,9 @@ namespace Slang { // Add the module and everything that isn't filtered out in the filter. writer.addPointer(moduleDecl); + // We can now serialize it into the riff container. - SLANG_RETURN_ON_FAIL(writer.writeIntoContainer(ASTSerialBinary::kSlangASTModuleFourCC, container)); + SLANG_RETURN_ON_FAIL(writer.writeIntoContainer(ASTSerialBinary::kSlangASTModuleDataFourCC, container)); } } } @@ -312,9 +318,97 @@ namespace Slang { SerialReader reader(serialClasses, &objectFactory); + // Sets up the entry table - one entry for each 'object'. + // No native objects are constructed. No objects are deserialized. + SLANG_RETURN_ON_FAIL(reader.loadEntries((const uint8_t*)astData->getPayload(), astData->getSize())); + + // Construct a native object for each table entry (where appropriate). + // Note that this *doesn't* set all object pointers - some are special cased and created on demand (strings) + // and imported symbols will have their object pointers unset (they are resolved in next step) + SLANG_RETURN_ON_FAIL(reader.constructObjects(options.namePool)); + + // Resolve external references if the linkage is specified + if (options.linkage) + { + const auto& entries = reader.getEntries(); + auto& objects = reader.getObjects(); + const Index entriesCount = entries.getCount(); + + String currentModuleName; + Module* currentModule = nullptr; + + // Index from 1 (0 is null) + for (Index i = 1; i < entriesCount; ++i) + { + const SerialInfo::Entry* entry = entries[i]; + if (entry->typeKind == SerialTypeKind::ImportSymbol) + { + UnownedStringSlice mangledName = reader.getStringSlice(SerialIndex(i)); + + UnownedStringSlice moduleName; + SLANG_RETURN_ON_FAIL(MangledNameParser::parseModuleName(mangledName, moduleName)); + + // If we already have looked up this module and it has the same name just use what we have + Module* readModule = nullptr; + if (currentModule && moduleName == currentModuleName.getUnownedSlice()) + { + readModule = currentModule; + } + else + { + // The modules are loaded on the linkage. + Linkage* linkage = options.linkage; + + NamePool* namePool = linkage->getNamePool(); + Name* moduleNameName = namePool->getName(moduleName); + + readModule = linkage->findOrImportModule(moduleNameName, SourceLoc::fromRaw(0), options.sink); + if (!readModule) + { + return SLANG_FAIL; + } + + // Set the current module and name + currentModule = readModule; + currentModuleName = moduleName; + } + + // Look up the symbol + NodeBase* nodeBase = readModule->findExportFromMangledName(mangledName); + + if (!nodeBase) + { + if (options.sink) + { + options.sink->diagnose(SourceLoc::fromRaw(0), Diagnostics::unableToFindSymbolInModule, mangledName, moduleName); + } + + // If didn't find the export then we are done + return SLANG_FAIL; + } + + // set the result + objects[i] = nodeBase; + } + } + } + + // Set the sourceLocReader before doing de-serialize, such can lookup the remapped sourceLocs reader.getExtraObjects().set(sourceLocReader); - SLANG_RETURN_ON_FAIL(reader.load((const uint8_t*)astData->getPayload(), astData->getSize(), options.namePool)); + // TODO(JS): + // If modules can have more complicated relationships (like a two modules can refer to symbols + // from each other), then we can make this work by + // 1) deserialize *without* the external symbols being set up + // 2) calculate the symbols + // 3) deserialize the other module (in the same way) + // 4) run deserializeObjects *again* on each module + // This is less efficient than it might be (because deserialize phase is done twice) so if this is necessary + // may want a mechanism that *just* does reference lookups. + // + // For now if we assume a module can only access symbols from another module, and not the reverse. + // So we just need to deserialize and we are done + SLANG_RETURN_ON_FAIL(reader.deserializeObjects()); // Get the root node. It's at index 1 (0 is the null value). astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast(); diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h index c34bebb6c..bc34982e7 100644 --- a/source/slang/slang-serialize-container.h +++ b/source/slang/slang-serialize-container.h @@ -89,6 +89,8 @@ struct SerialContainerUtil SourceManager* sourceManager = nullptr; NamePool* namePool = nullptr; SharedASTBuilder* sharedASTBuilder = nullptr; + Linkage* linkage = nullptr; + DiagnosticSink* sink = nullptr; }; /// Get the serializable contents of the request as data diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp index 2bb7b047e..f93fbba69 100644 --- a/source/slang/slang-serialize-factory.cpp +++ b/source/slang/slang-serialize-factory.cpp @@ -73,18 +73,11 @@ SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBas { ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); - // It's a reference to a declaration in another module, so create an ImportExternalDecl. - + // It's a reference to a declaration in another module, so first get the symbol name. String mangledName = getMangledName(astBuilder, decl); - ImportExternalDecl* importDecl = astBuilder->create(); - importDecl->mangledName = mangledName; - const SerialIndex index = writer->addPointer(importDecl); - - // Set as the index of this - writer->setPointerIndex(ptr, index); - - return index; + // Add as an import symbol + return writer->addImportSymbol(mangledName); } else { diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp index 46aae18b2..34680d860 100644 --- a/source/slang/slang-serialize.cpp +++ b/source/slang/slang-serialize.cpp @@ -307,7 +307,7 @@ SerialIndex SerialWriter::addPointer(const RefObject* obj) } } -SerialIndex SerialWriter::addString(const UnownedStringSlice& slice) +SerialIndex SerialWriter::_addStringSlice(SerialTypeKind typeKind, SliceMap& sliceMap, const UnownedStringSlice& slice) { typedef ByteEncodeUtil Util; typedef SerialInfo::StringEntry StringEntry; @@ -317,9 +317,7 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice) return SerialIndex(0); } - Index newIndex = m_entries.getCount(); - - Index* indexPtr = m_sliceMap.TryGetValueOrAdd(slice, newIndex); + Index* indexPtr = sliceMap.TryGetValue(slice); if (indexPtr) { return SerialIndex(*indexPtr); @@ -332,7 +330,7 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice) StringEntry* entry = (StringEntry*)m_arena.allocateUnaligned(SLANG_OFFSET_OF(StringEntry, sizeAndChars) + encodeCount + slice.getLength()); entry->info = SerialInfo::EntryInfo::Alignment1; - entry->typeKind = SerialTypeKind::String; + entry->typeKind = typeKind; uint8_t* dst = (uint8_t*)(entry->sizeAndChars); for (int i = 0; i < encodeCount; ++i) @@ -342,6 +340,13 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice) memcpy(dst + encodeCount, slice.begin(), slice.getLength()); + // Make a key that will stay in scope -> it's actually just stored in the arena. + // NOTE! without terminating 0 + UnownedStringSlice keySlice(((const char*)dst) + encodeCount, slice.getLength()); + + Index newIndex = m_entries.getCount(); + sliceMap.Add(keySlice, newIndex); + m_entries.add(entry); return SerialIndex(newIndex); } @@ -550,6 +555,7 @@ size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const { switch (typeKind) { + case SerialTypeKind::ImportSymbol: case SerialTypeKind::String: { auto entry = static_cast(this); @@ -632,6 +638,8 @@ SerialPointer SerialReader::getPointer(SerialIndex index) SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); const Entry* entry = m_entries[Index(index)]; + const SerialPointer& ptr = m_objects[Index(index)]; + switch (entry->typeKind) { case SerialTypeKind::String: @@ -640,19 +648,21 @@ SerialPointer SerialReader::getPointer(SerialIndex index) String string = getString(index); return SerialPointer(string.getStringRepresentation()); } - case SerialTypeKind::NodeBase: - { - return SerialPointer((NodeBase*)m_objects[Index(index)]); - } - case SerialTypeKind::RefObject: + case SerialTypeKind::ImportSymbol: { - return SerialPointer((RefObject*)m_objects[Index(index)]); + if (ptr.m_kind == SerialTypeKind::Unknown) + { + // TODO(JS): + // Could have an error here, because import symbol was not set + // For now just return nullptr + return SerialPointer(); + } + break; } default: break; } - SLANG_ASSERT(!"Cannot access as a pointer"); - return SerialPointer(); + return ptr; } String SerialReader::getString(SerialIndex index) @@ -672,7 +682,7 @@ String SerialReader::getString(SerialIndex index) return String(); } - RefObject* obj = (RefObject*)m_objects[Index(index)]; + RefObject* obj = m_objects[Index(index)].dynamicCast(); if (obj) { @@ -714,7 +724,7 @@ Name* SerialReader::getName(SerialIndex index) return nullptr; } - RefObject* obj = (RefObject*)m_objects[Index(index)]; + RefObject* obj = m_objects[Index(index)].dynamicCast(); if (obj) { @@ -728,7 +738,7 @@ Name* SerialReader::getName(SerialIndex index) SLANG_ASSERT(stringRep); // I don't need to scope, as scoped in NamePool - name = m_namePool->getName(String(stringRep)); + name = m_namePool->getName(String(stringRep)); // Store as name, as can always access the inner string if needed m_objects[Index(index)] = name; @@ -749,23 +759,25 @@ UnownedStringSlice SerialReader::getStringSlice(SerialIndex index) const Entry* entry = m_entries[Index(index)]; // It has to be a string type - if (entry->typeKind != SerialTypeKind::String) + if (entry->typeKind == SerialTypeKind::String || + entry->typeKind == SerialTypeKind::ImportSymbol) { - SLANG_ASSERT(!"Not a string"); - return UnownedStringSlice(); - } + auto stringEntry = static_cast(entry); - auto stringEntry = static_cast(entry); + const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars; - const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars; + // Decode the string + uint32_t size; + int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size); + return UnownedStringSlice((const char*)src + sizeSize, size); + } - // Decode the string - uint32_t size; - int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size); - return UnownedStringSlice((const char*)src + sizeSize, size); + // Can't be accessed as a slice + SLANG_ASSERT(!"Not accessible as a slice"); + return UnownedStringSlice(); } -SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, List& outEntries) +/* static */SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, SerialClasses* serialClasses, List& outEntries) { // Check the input data is at least aligned to the max alignment (otherwise everything cannot be aligned correctly) SLANG_ASSERT((size_t(data) & (SerialInfo::MAX_ALIGNMENT - 1)) == 0); @@ -781,7 +793,7 @@ SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, Lis const Entry* entry = (const Entry*)cur; outEntries.add(entry); - const size_t entrySize = entry->calcSize(m_classes); + const size_t entrySize = entry->calcSize(serialClasses); cur += entrySize; // Need to get the next alignment @@ -794,10 +806,8 @@ SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, Lis return SLANG_OK; } -SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool) +SlangResult SerialReader::constructObjects(NamePool* namePool) { - SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount, m_entries)); - m_namePool = namePool; m_objects.clearAndDeallocate(); @@ -811,6 +821,13 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* switch (entry->typeKind) { + case SerialTypeKind::ImportSymbol: + { + // We don't construct any object for an imported symbol. + // It will be the responsibility of external code to interpet the symbols and *set* the appopriate + // objects prior to a call to `deserializeObjects` + break; + } case SerialTypeKind::String: { // Don't need to construct an object. This is probably a StringRepresentation, or a Name @@ -826,7 +843,7 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* { return SLANG_FAIL; } - m_objects[i] = obj; + m_objects[i].set(entry->typeKind, obj); break; } case SerialTypeKind::Array: @@ -837,12 +854,18 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* } } + return SLANG_OK; +} + +SlangResult SerialReader::deserializeObjects() +{ // Deserialize for (Index i = 1; i < m_entries.getCount(); ++i) { const Entry* entry = m_entries[i]; - void* native = m_objects[i]; - if (!native) + // First see if there is anything to construct + SerialPointer& dstPtr = m_objects[i]; + if (!dstPtr) { continue; } @@ -859,7 +882,7 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* } const uint8_t* src = (const uint8_t*)(objectEntry + 1); - uint8_t* dst = (uint8_t*)m_objects[i]; + uint8_t* dst = (uint8_t*)dstPtr.m_ptr; // It must be constructed SLANG_ASSERT(dst); @@ -886,4 +909,15 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* return SLANG_OK; } + +SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool) +{ + // Load and place entries into entries table + SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount)); + // Construct all of the objects + SLANG_RETURN_ON_FAIL(constructObjects(namePool)); + SLANG_RETURN_ON_FAIL(deserializeObjects()); + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index 0e7fdd68a..a48a7e216 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -46,6 +46,7 @@ enum class SerialTypeKind : uint8_t String, ///< String Array, ///< Array + ImportSymbol, ///< Holds the name of the import symbol. Represented in exactly the same way as a string NodeBase, ///< NodeBase derived RefObject, ///< RefObject derived types @@ -160,6 +161,12 @@ struct SerialPointer { } + /// True if the ptr is set + SLANG_FORCE_INLINE operator bool() const { return m_ptr != nullptr; } + + /// Directly set pointer/kind + void set(SerialTypeKind kind, void* ptr) { m_kind = kind; m_ptr = ptr; } + static SerialTypeKind getKind(const RefObject*) { return SerialTypeKind::RefObject; } static SerialTypeKind getKind(const NodeBase*) { return SerialTypeKind::NodeBase; } @@ -221,13 +228,23 @@ public: Name* getName(SerialIndex index); UnownedStringSlice getStringSlice(SerialIndex index); - /// Load the entries table (without deserializing anything) - /// NOTE! data must stay ins scope for outEntries to be valid - SlangResult loadEntries(const uint8_t* data, size_t dataCount, List& outEntries); + SlangResult loadEntries(const uint8_t* data, size_t dataCount) { return loadEntries(data, dataCount, m_classes, m_entries); } + /// For each entry construct an object. Does *NOT* deserialize them + SlangResult constructObjects(NamePool* namePool); + /// Entries must be loaded (with loadEntries), and objects constructed (with constructObjects) before deserializing + SlangResult deserializeObjects(); /// NOTE! data must stay ins scope when reading takes place SlangResult load(const uint8_t* data, size_t dataCount, NamePool* namePool); + /// Get the entries list + const List& getEntries() const { return m_entries; } + + /// Access the objects list + /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope + List& getObjects() { return m_objects; } + const List& getObjects() const { return m_objects; } + /// Add an object to be kept in scope void addScope(const RefObject* obj) { m_scope.add(obj); } @@ -242,9 +259,14 @@ public: } ~SerialReader(); + /// Load the entries table (without deserializing anything) + /// NOTE! data must stay ins scope for outEntries to be valid + static SlangResult loadEntries(const uint8_t* data, size_t dataCount, SerialClasses* serialClasses, List& outEntries); + protected: List m_entries; ///< The entries - List m_objects; ///< The constructed objects + + List m_objects; ///< The constructed objects NamePool* m_namePool; ///< Pool names are added to List m_scope; ///< Keeping objects in scope @@ -305,10 +327,15 @@ public: template SerialIndex addArray(const T* in, Index count); - SerialIndex addString(const UnownedStringSlice& slice); + /// Add the string + SerialIndex addString(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::String, m_sliceMap, slice); } SerialIndex addString(const String& in); SerialIndex addName(const Name* name); - + + /// Adding import symbols + SerialIndex addImportSymbol(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); } + SerialIndex addImportSymbol(const String& string){ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice()); } + /// Set a the ptr associated with an index. /// NOTE! That there cannot be a pre-existing setting. void setPointerIndex(const NodeBase* ptr, SerialIndex index); @@ -331,6 +358,10 @@ public: protected: + typedef Dictionary SliceMap; + + SerialIndex _addStringSlice(SerialTypeKind typeKind, SliceMap& sliceMap, const UnownedStringSlice& slice); + SerialIndex _addArray(size_t elementSize, size_t alignment, const void* elements, Index elementCount); SerialIndex _add(const void* nativePtr, SerialInfo::Entry* entry) @@ -345,8 +376,10 @@ protected: Dictionary m_ptrMap; // Maps a pointer to an entry index + // NOTE! Assumes the content stays in scope! - Dictionary m_sliceMap; + SliceMap m_sliceMap; + SliceMap m_importSymbolMap; SerialExtraObjects m_extraObjects; ///< Extra objects diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index ad4e82d4f..90ddf030c 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -141,9 +141,8 @@ void Session::init() builtinSourceManager.initialize(nullptr, nullptr); // Built in linkage uses the built in builder - m_builtinLinkage = new Linkage(this, builtinAstBuilder); + m_builtinLinkage = new Linkage(this, builtinAstBuilder, nullptr); - // Because the `Session` retains the builtin `Linkage`, // we need to make sure that the parent pointer inside // `Linkage` doesn't create a retain cycle. @@ -207,7 +206,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession( slang::ISession** outSession) { RefPtr astBuilder(new ASTBuilder(m_sharedASTBuilder, "Session::astBuilder")); - RefPtr linkage = new Linkage(this, astBuilder); + RefPtr linkage = new Linkage(this, astBuilder, getBuiltinLinkage()); Int targetCount = desc.targetCount; for(Int ii = 0; ii < targetCount; ++ii) @@ -457,7 +456,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) // -Linkage::Linkage(Session* session, ASTBuilder* astBuilder) +Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage) : m_session(session) , m_retainedSession(session) , m_sourceManager(&m_defaultSourceManager) @@ -468,6 +467,15 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder) m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); setFileSystem(nullptr); + + // Copy of the built in linkages modules + if (builtinLinkage) + { + for (const auto& pair : builtinLinkage->mapNameToLoadedModules) + { + mapNameToLoadedModules.Add(pair.Key, pair.Value); + } + } } ISlangUnknown* Linkage::getInterface(const Guid& guid) @@ -1381,7 +1389,7 @@ EndToEndCompileRequest::EndToEndCompileRequest( , m_sink(nullptr) { RefPtr astBuilder(new ASTBuilder(session->m_sharedASTBuilder, "EndToEnd::Linkage::astBuilder")); - m_linkage = new Linkage(session, astBuilder); + m_linkage = new Linkage(session, astBuilder, session->getBuiltinLinkage()); init(); } @@ -1945,6 +1953,7 @@ void FilePathDependencyList::addDependency(Module* module) Module::Module(Linkage* linkage) : ComponentType(linkage) , m_astBuilder(linkage->getASTBuilder()->getSharedASTBuilder(), "Module") + , m_mangledExportPool(StringSlicePool::Style::Empty) { addModuleDependency(this); } @@ -1995,6 +2004,87 @@ void Module::_addEntryPoint(EntryPoint* entryPoint) m_entryPoints.add(entryPoint); } +static bool _canExportDeclSymbol(ASTNodeType type) +{ + switch (type) + { + case ASTNodeType::ModuleDecl: + case ASTNodeType::EmptyDecl: + case ASTNodeType::NamespaceDecl: + { + return false; + } + default: break; + } + + return true; +} + +static bool _canRecurseExportSymbol(Decl* decl) +{ + if (as(decl) || + as(decl)) + { + return false; + } + return true; +} + +void Module::_processFindDeclsExportSymbolsRec(Decl* decl) +{ + if (_canExportDeclSymbol(decl->astNodeType)) + { + // It's a reference to a declaration in another module, so first get the symbol name. + String mangledName = getMangledName(getASTBuilder(), decl); + + Index index = Index(m_mangledExportPool.add(mangledName)); + + // TODO(JS): It appears that more than one entity might have the same mangled name. + // So for now we ignore and just take the first one. + if (index == m_mangledExportSymbols.getCount()) + { + m_mangledExportSymbols.add(decl); + } + } + + if (!_canRecurseExportSymbol(decl)) + { + // We don't need to recurse any further into this + return; + } + + // process `decl` itself + if(auto containerDecl = as(decl)) + { + for (auto child : containerDecl->members) + { + _processFindDeclsExportSymbolsRec(child); + } + } + else if (auto genericDecl = as(decl)) + { + _processFindDeclsExportSymbolsRec(genericDecl->inner); + } +} + +NodeBase* Module::findExportFromMangledName(const UnownedStringSlice& slice) +{ + // Will be non zero if has been previously attempted + if (m_mangledExportSymbols.getCount() == 0) + { + // Build up the exported mangled name list + _processFindDeclsExportSymbolsRec(getModuleDecl()); + + // If nothing found, mark that we have tried looking by making m_mangledExportSymbols.getCount() != 0 + if (m_mangledExportSymbols.getCount() == 0) + { + m_mangledExportSymbols.add(nullptr); + } + } + + const Index index = m_mangledExportPool.findIndex(slice); + return (index >= 0) ? m_mangledExportSymbols[index] : nullptr; +} // ComponentType @@ -2742,6 +2832,7 @@ void Session::addBuiltinSource( Name* moduleName = getNamePool()->getName(path); auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); + compileRequest->addTranslationUnitSourceString( translationUnitIndex, path, @@ -2764,6 +2855,9 @@ void Session::addBuiltinSource( auto module = compileRequest->translationUnits[translationUnitIndex]->getModule(); auto moduleDecl = module->getModuleDecl(); + // Put in the loaded module map + linkage->mapNameToLoadedModules.Add(moduleName, module); + // Add the resulting code to the appropriate scope if (!scope->containerDecl) { @@ -3189,6 +3283,8 @@ SlangResult _addLibraryReference(EndToEndCompileRequest* req, Stream* stream) options.session = req->getSession(); options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); options.sourceManager = linkage->getSourceManager(); + options.linkage = req->getLinkage(); + options.sink = req->getSink(); SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, containerData)); -- cgit v1.2.3