summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2020-10-29 11:45:56 -0400
committerGitHub <noreply@github.com>2020-10-29 08:45:56 -0700
commit494e09af2cebafa34db49dc1f60afd43aebed619 (patch)
treeb3985b21d4470415a3ad1a6183836528a971ca54 /source
parent1d7a7f23874151372f2792e7307f50c54dae877f (diff)
Handling imported/exporting symbols from serialized modules (#1589)
* #include an absolute path didn't work - because paths were taken to always be relative. * Fix handling of access modifiers inside type definition. * Fix access problem for AST node. Make dumping produce a single function with switch, to potentially make available without Dump specific access. * WIP on serialization design doc. * Remove project references to previously generated files. * More docs on serialization design. * Improve serialization documentation. Remove unused function from IRSerialReader. * Small fixes around naming. Remove long comment from slang-serialize.h - as covered in serialization.md * Remove long comment in slang-serialize.h as covered in serialization.md * More information about doing replacements on read for AST and problems surrounding. * Typo fix. * Spelling fixes. * Value serialize. * Value types with inheritence. * Use value reflection serial conversion for more AST types * Use automatic serialization on more of AST. * Get the types via decltype, simplifies what the extractor has to do. * Update the serialization.md for the value serialization. * Small doc improvements. * Update project. * Remove ImportExternalDecl type Added addImportSymbol and ImportSymbol type Fixed bug in container which meant it wouldn't read back AST module * Because of change of how imports and handled, store objects as SerialPointers. * First pass symbol lookup from mangled names. * Cache current module looked up from mangled name. * Fix SourceLoc bug. Improve comments. * Added diagnostic on mangled symbol not being found * Fix typo. Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h12
-rw-r--r--source/slang/slang-check-decl.cpp11
-rwxr-xr-xsource/slang/slang-compiler.h15
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-mangled-lexer.cpp55
-rw-r--r--source/slang/slang-mangled-lexer.h32
-rw-r--r--source/slang/slang-serialize-ast.cpp6
-rw-r--r--source/slang/slang-serialize-container.cpp98
-rw-r--r--source/slang/slang-serialize-container.h2
-rw-r--r--source/slang/slang-serialize-factory.cpp13
-rw-r--r--source/slang/slang-serialize.cpp104
-rw-r--r--source/slang/slang-serialize.h47
-rw-r--r--source/slang/slang.cpp106
13 files changed, 390 insertions, 113 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 4210318dc..2bf4c488b 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -481,16 +481,4 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
-// Import a Declaration that has been defined 'externally' and referenced in this module.
-// Includes the managed name of the declaration.
-// This declaration can be added when a module is serialized to any declaration that is
-// external to the module being serialized, such that it replaces the reference to the
-// declaration in another module.
-class ImportExternalDecl : public DeclBase
-{
- SLANG_AST_CLASS(ImportExternalDecl)
-
- String mangledName;
-};
-
} // namespace Slang
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0735f4620..d771e689c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -24,8 +24,7 @@ namespace Slang
{}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
-
+
void visitDecl(Decl* decl)
{
checkModifiers(decl);
@@ -42,9 +41,7 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
-
-
+
void checkVarDeclCommon(VarDeclBase* varDecl);
void visitVarDecl(VarDecl* varDecl)
@@ -113,7 +110,6 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
#define CASE(TYPE) void visit##TYPE(TYPE* decl) { checkForRedeclaration(decl); }
@@ -135,7 +131,6 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
void visitInheritanceDecl(InheritanceDecl* inheritanceDecl);
@@ -170,7 +165,6 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
void checkVarDeclCommon(VarDeclBase* varDecl);
@@ -1186,7 +1180,6 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
- void visitImportExternalDecl(ImportExternalDecl*) {}
// Any user-defined type may have declared interface conformances,
// which we should check.
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 79ef9df27..56885ab46 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -973,6 +973,10 @@ namespace Slang
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); }
List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); }
+ /// Given a mangled name finds the exported NodeBase associated with this module.
+ /// If not found returns nullptr.
+ NodeBase* findExportFromMangledName(const UnownedStringSlice& slice);
+
/// Get the ASTBuilder
ASTBuilder* getASTBuilder() { return &m_astBuilder; }
@@ -1004,6 +1008,7 @@ namespace Slang
List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; }
void _addEntryPoint(EntryPoint* entryPoint);
+ void _processFindDeclsExportSymbolsRec(Decl* decl);
protected:
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
@@ -1054,6 +1059,11 @@ namespace Slang
// The builder that owns all of the AST nodes from parsing the source of
// this module.
ASTBuilder m_astBuilder;
+
+ // Holds map of exported mangled names to symbols. m_mangledExportPool maps names to indices,
+ // and m_mangledExportSymbols holds the NodeBase* values for each index.
+ StringSlicePool m_mangledExportPool;
+ List<NodeBase*> m_mangledExportSymbols;
};
typedef Module LoadedModule;
@@ -1234,7 +1244,7 @@ namespace Slang
SlangMatrixLayoutMode mode);
/// Create an initially-empty linkage
- Linkage(Session* session, ASTBuilder* astBuilder);
+ Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage);
/// Dtor
~Linkage();
@@ -2137,6 +2147,9 @@ namespace Slang
/// Get the prelude associated with the language
const String& getPreludeForLanguage(SourceLanguage language) { return m_languagePreludes[int(language)]; }
+ /// Get the built in linkage -> handy to get the stdlibs from
+ Linkage* getBuiltinLinkage() const { return m_builtinLinkage; }
+
void init();
void addBuiltinSource(
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d7f561537..9b4d55c14 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -397,6 +397,8 @@ DIAGNOSTIC(39999, Warning, integerLiteralTruncated, "integer literal '$0' too la
DIAGNOSTIC(39999, Warning, floatLiteralUnrepresentable, "$0 literal '$1' unrepresentable, converted to '$2'")
DIAGNOSTIC(39999, Warning, floatLiteralTooSmall, "'$1' is smaller than the smallest representable value for type $0, converted to '$2'")
+DIAGNOSTIC(39999, Error, unableToFindSymbolInModule, "unable to find the mangled symbol '$0' in module '$1'")
+
// 38xxx
DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'")
diff --git a/source/slang/slang-mangled-lexer.cpp b/source/slang/slang-mangled-lexer.cpp
index f1f5ec903..237f9f2a5 100644
--- a/source/slang/slang-mangled-lexer.cpp
+++ b/source/slang/slang-mangled-lexer.cpp
@@ -7,13 +7,13 @@ namespace Slang {
UInt MangledLexer::readCount()
{
- int c = _peek();
- if (!_isDigit((char)c))
+ int c = peekChar();
+ if (!CharUtil::isDigit((char)c))
{
SLANG_UNEXPECTED("bad name mangling");
UNREACHABLE_RETURN(0);
}
- _next();
+ nextChar();
if (c == '0')
return 0;
@@ -22,25 +22,25 @@ UInt MangledLexer::readCount()
for (;;)
{
count = count * 10 + c - '0';
- c = _peek();
- if (!_isDigit((char)c))
+ c = peekChar();
+ if (!CharUtil::isDigit((char)c))
return count;
- _next();
+ nextChar();
}
}
void MangledLexer::readGenericParam()
{
- switch (_peek())
+ switch (peekChar())
{
case 'T':
case 'C':
- _next();
+ nextChar();
break;
case 'v':
- _next();
+ nextChar();
readType();
break;
@@ -62,7 +62,7 @@ void MangledLexer::readGenericParams()
void MangledLexer::readType()
{
- int c = _peek();
+ int c = peekChar();
switch (c)
{
case 'V':
@@ -73,11 +73,11 @@ void MangledLexer::readType()
case 'h':
case 'f':
case 'd':
- _next();
+ nextChar();
break;
case 'v':
- _next();
+ nextChar();
readSimpleIntVal();
readType();
break;
@@ -90,15 +90,15 @@ void MangledLexer::readType()
void MangledLexer::readVal()
{
- switch (_peek())
+ switch (peekChar())
{
case 'k':
- _next();
+ nextChar();
readCount();
break;
case 'K':
- _next();
+ nextChar();
readRawStringSegment();
break;
@@ -124,7 +124,7 @@ UnownedStringSlice MangledLexer::readSimpleName()
UnownedStringSlice result;
for (;;)
{
- int c = _peek();
+ int c = peekChar();
if (c == 'g')
{
@@ -142,7 +142,7 @@ UnownedStringSlice MangledLexer::readSimpleName()
continue;
}
- if (!_isDigit((char)c))
+ if (!CharUtil::isDigit((char)c))
return result;
// Read the length part
@@ -181,4 +181,25 @@ UInt MangledLexer::readParamCount()
return count;
}
+/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MangledNameParser !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
+
+/* static */SlangResult MangledNameParser::parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName)
+{
+ MangledLexer lexer(in);
+
+ if (lexer.peekChar() == 'T')
+ {
+ lexer.nextChar();
+ }
+
+ UnownedStringSlice name = lexer.readRawStringSegment();
+ if (name.getLength() == 0)
+ {
+ return SLANG_FAIL;
+ }
+
+ outModuleName = name;
+ return SLANG_OK;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-mangled-lexer.h b/source/slang/slang-mangled-lexer.h
index 6c8060cfd..7d096e45e 100644
--- a/source/slang/slang-mangled-lexer.h
+++ b/source/slang/slang-mangled-lexer.h
@@ -3,6 +3,7 @@
#define SLANG_MANGLED_LEXER_H_INCLUDED
#include "../core/slang-basic.h"
+#include "../core/slang-char-util.h"
#include "slang-compiler.h"
@@ -41,6 +42,12 @@ public:
UInt readParamCount();
+ /// Returns the character at the current position
+ char peekChar() { return *m_cursor; }
+ // Returns the current character and moves to next character.
+ char nextChar() { return *m_cursor++; }
+
+
/// Ctor
SLANG_FORCE_INLINE MangledLexer(const UnownedStringSlice& slice);
@@ -50,13 +57,6 @@ private:
// to strip off the main prefix
void _start() { _expect("_S"); }
- static bool _isDigit(char c) { return (c >= '0') && (c <= '9'); }
-
- /// Returns the character at the current position
- char _peek() { return *m_cursor; }
- // Returns the current character and moves to next character.
- char _next() { return *m_cursor++; }
-
SLANG_INLINE void _expect(char c);
void _expect(char const* str)
@@ -82,10 +82,10 @@ SLANG_FORCE_INLINE MangledLexer::MangledLexer(const UnownedStringSlice& slice)
// ---------------------------------------------------------------------------
SLANG_INLINE void MangledLexer::readSimpleIntVal()
{
- int c = _peek();
- if (_isDigit((char)c))
+ int c = peekChar();
+ if (CharUtil::isDigit((char)c))
{
- _next();
+ nextChar();
}
else
{
@@ -110,9 +110,9 @@ SLANG_INLINE void MangledLexer::readExtensionSpec()
// ---------------------------------------------------------------------------
SLANG_INLINE void MangledLexer::_expect(char c)
{
- if (_peek() == c)
+ if (peekChar() == c)
{
- _next();
+ nextChar();
}
else
{
@@ -121,5 +121,13 @@ SLANG_INLINE void MangledLexer::_expect(char c)
}
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MangledNameParser !!!!!!!!!!!!!!!!!!!!!!!!!!
+
+struct MangledNameParser
+{
+ /// Tries to extract the module name from this mangled name.
+ static SlangResult parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName);
+};
+
}
#endif
diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp
index 0e8acc3b3..bbd8237d2 100644
--- a/source/slang/slang-serialize-ast.cpp
+++ b/source/slang/slang-serialize-ast.cpp
@@ -84,8 +84,6 @@ struct ASTFieldAccess
NamePool namePool;
namePool.setRootNamePool(rootNamePool);
- SerialReader reader(classes, nullptr);
-
ASTBuilder builder(sharedASTBuilder, "Serialize Check");
DefaultSerialObjectFactory objectFactory(&builder);
@@ -96,7 +94,7 @@ struct ASTFieldAccess
const List<SerialInfo::Entry*>& writtenEntries = writer.getEntries();
List<const SerialInfo::Entry*> readEntries;
- SlangResult res = reader.loadEntries(contents.getBuffer(), contents.getCount(), readEntries);
+ SlangResult res = SerialReader::loadEntries(contents.getBuffer(), contents.getCount(), classes, readEntries);
SLANG_UNUSED(res);
SLANG_ASSERT(writtenEntries.getCount() == readEntries.getCount());
@@ -117,7 +115,9 @@ struct ASTFieldAccess
}
+ SerialReader reader(classes, nullptr);
{
+
SlangResult res = reader.load(contents.getBuffer(), contents.getCount(), &namePool);
SLANG_UNUSED(res);
}
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index 7dd2c15a1..8c4edb9a0 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -12,6 +12,8 @@
#include "slang-serialize-source-loc.h"
#include "slang-serialize-factory.h"
+#include "slang-mangled-lexer.h"
+
namespace Slang {
/* static */SlangResult SerialContainerUtil::requestToData(EndToEndCompileRequest* request, const WriteOptions& options, SerialContainerData& out)
@@ -142,6 +144,9 @@ namespace Slang {
{
if (ModuleDecl* moduleDecl = as<ModuleDecl>(module.astRootNode))
{
+ // Put in AST module
+ RiffContainer::ScopeChunk scopeASTModule(container, RiffContainer::Chunk::Kind::List, ASTSerialBinary::kSlangASTModuleFourCC);
+
if (!serialClasses)
{
SLANG_RETURN_ON_FAIL(SerialClassesUtil::create(serialClasses));
@@ -155,8 +160,9 @@ namespace Slang {
// Add the module and everything that isn't filtered out in the filter.
writer.addPointer(moduleDecl);
+
// We can now serialize it into the riff container.
- SLANG_RETURN_ON_FAIL(writer.writeIntoContainer(ASTSerialBinary::kSlangASTModuleFourCC, container));
+ SLANG_RETURN_ON_FAIL(writer.writeIntoContainer(ASTSerialBinary::kSlangASTModuleDataFourCC, container));
}
}
}
@@ -312,9 +318,97 @@ namespace Slang {
SerialReader reader(serialClasses, &objectFactory);
+ // Sets up the entry table - one entry for each 'object'.
+ // No native objects are constructed. No objects are deserialized.
+ SLANG_RETURN_ON_FAIL(reader.loadEntries((const uint8_t*)astData->getPayload(), astData->getSize()));
+
+ // Construct a native object for each table entry (where appropriate).
+ // Note that this *doesn't* set all object pointers - some are special cased and created on demand (strings)
+ // and imported symbols will have their object pointers unset (they are resolved in next step)
+ SLANG_RETURN_ON_FAIL(reader.constructObjects(options.namePool));
+
+ // Resolve external references if the linkage is specified
+ if (options.linkage)
+ {
+ const auto& entries = reader.getEntries();
+ auto& objects = reader.getObjects();
+ const Index entriesCount = entries.getCount();
+
+ String currentModuleName;
+ Module* currentModule = nullptr;
+
+ // Index from 1 (0 is null)
+ for (Index i = 1; i < entriesCount; ++i)
+ {
+ const SerialInfo::Entry* entry = entries[i];
+ if (entry->typeKind == SerialTypeKind::ImportSymbol)
+ {
+ UnownedStringSlice mangledName = reader.getStringSlice(SerialIndex(i));
+
+ UnownedStringSlice moduleName;
+ SLANG_RETURN_ON_FAIL(MangledNameParser::parseModuleName(mangledName, moduleName));
+
+ // If we already have looked up this module and it has the same name just use what we have
+ Module* readModule = nullptr;
+ if (currentModule && moduleName == currentModuleName.getUnownedSlice())
+ {
+ readModule = currentModule;
+ }
+ else
+ {
+ // The modules are loaded on the linkage.
+ Linkage* linkage = options.linkage;
+
+ NamePool* namePool = linkage->getNamePool();
+ Name* moduleNameName = namePool->getName(moduleName);
+
+ readModule = linkage->findOrImportModule(moduleNameName, SourceLoc::fromRaw(0), options.sink);
+ if (!readModule)
+ {
+ return SLANG_FAIL;
+ }
+
+ // Set the current module and name
+ currentModule = readModule;
+ currentModuleName = moduleName;
+ }
+
+ // Look up the symbol
+ NodeBase* nodeBase = readModule->findExportFromMangledName(mangledName);
+
+ if (!nodeBase)
+ {
+ if (options.sink)
+ {
+ options.sink->diagnose(SourceLoc::fromRaw(0), Diagnostics::unableToFindSymbolInModule, mangledName, moduleName);
+ }
+
+ // If didn't find the export then we are done
+ return SLANG_FAIL;
+ }
+
+ // set the result
+ objects[i] = nodeBase;
+ }
+ }
+ }
+
+ // Set the sourceLocReader before doing de-serialize, such can lookup the remapped sourceLocs
reader.getExtraObjects().set(sourceLocReader);
- SLANG_RETURN_ON_FAIL(reader.load((const uint8_t*)astData->getPayload(), astData->getSize(), options.namePool));
+ // TODO(JS):
+ // If modules can have more complicated relationships (like a two modules can refer to symbols
+ // from each other), then we can make this work by
+ // 1) deserialize *without* the external symbols being set up
+ // 2) calculate the symbols
+ // 3) deserialize the other module (in the same way)
+ // 4) run deserializeObjects *again* on each module
+ // This is less efficient than it might be (because deserialize phase is done twice) so if this is necessary
+ // may want a mechanism that *just* does reference lookups.
+ //
+ // For now if we assume a module can only access symbols from another module, and not the reverse.
+ // So we just need to deserialize and we are done
+ SLANG_RETURN_ON_FAIL(reader.deserializeObjects());
// Get the root node. It's at index 1 (0 is the null value).
astRootNode = reader.getPointer(SerialIndex(1)).dynamicCast<NodeBase>();
diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h
index c34bebb6c..bc34982e7 100644
--- a/source/slang/slang-serialize-container.h
+++ b/source/slang/slang-serialize-container.h
@@ -89,6 +89,8 @@ struct SerialContainerUtil
SourceManager* sourceManager = nullptr;
NamePool* namePool = nullptr;
SharedASTBuilder* sharedASTBuilder = nullptr;
+ Linkage* linkage = nullptr;
+ DiagnosticSink* sink = nullptr;
};
/// Get the serializable contents of the request as data
diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp
index 2bb7b047e..f93fbba69 100644
--- a/source/slang/slang-serialize-factory.cpp
+++ b/source/slang/slang-serialize-factory.cpp
@@ -73,18 +73,11 @@ SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBas
{
ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder();
- // It's a reference to a declaration in another module, so create an ImportExternalDecl.
-
+ // It's a reference to a declaration in another module, so first get the symbol name.
String mangledName = getMangledName(astBuilder, decl);
- ImportExternalDecl* importDecl = astBuilder->create<ImportExternalDecl>();
- importDecl->mangledName = mangledName;
- const SerialIndex index = writer->addPointer(importDecl);
-
- // Set as the index of this
- writer->setPointerIndex(ptr, index);
-
- return index;
+ // Add as an import symbol
+ return writer->addImportSymbol(mangledName);
}
else
{
diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp
index 46aae18b2..34680d860 100644
--- a/source/slang/slang-serialize.cpp
+++ b/source/slang/slang-serialize.cpp
@@ -307,7 +307,7 @@ SerialIndex SerialWriter::addPointer(const RefObject* obj)
}
}
-SerialIndex SerialWriter::addString(const UnownedStringSlice& slice)
+SerialIndex SerialWriter::_addStringSlice(SerialTypeKind typeKind, SliceMap& sliceMap, const UnownedStringSlice& slice)
{
typedef ByteEncodeUtil Util;
typedef SerialInfo::StringEntry StringEntry;
@@ -317,9 +317,7 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice)
return SerialIndex(0);
}
- Index newIndex = m_entries.getCount();
-
- Index* indexPtr = m_sliceMap.TryGetValueOrAdd(slice, newIndex);
+ Index* indexPtr = sliceMap.TryGetValue(slice);
if (indexPtr)
{
return SerialIndex(*indexPtr);
@@ -332,7 +330,7 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice)
StringEntry* entry = (StringEntry*)m_arena.allocateUnaligned(SLANG_OFFSET_OF(StringEntry, sizeAndChars) + encodeCount + slice.getLength());
entry->info = SerialInfo::EntryInfo::Alignment1;
- entry->typeKind = SerialTypeKind::String;
+ entry->typeKind = typeKind;
uint8_t* dst = (uint8_t*)(entry->sizeAndChars);
for (int i = 0; i < encodeCount; ++i)
@@ -342,6 +340,13 @@ SerialIndex SerialWriter::addString(const UnownedStringSlice& slice)
memcpy(dst + encodeCount, slice.begin(), slice.getLength());
+ // Make a key that will stay in scope -> it's actually just stored in the arena.
+ // NOTE! without terminating 0
+ UnownedStringSlice keySlice(((const char*)dst) + encodeCount, slice.getLength());
+
+ Index newIndex = m_entries.getCount();
+ sliceMap.Add(keySlice, newIndex);
+
m_entries.add(entry);
return SerialIndex(newIndex);
}
@@ -550,6 +555,7 @@ size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const
{
switch (typeKind)
{
+ case SerialTypeKind::ImportSymbol:
case SerialTypeKind::String:
{
auto entry = static_cast<const StringEntry*>(this);
@@ -632,6 +638,8 @@ SerialPointer SerialReader::getPointer(SerialIndex index)
SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount()));
const Entry* entry = m_entries[Index(index)];
+ const SerialPointer& ptr = m_objects[Index(index)];
+
switch (entry->typeKind)
{
case SerialTypeKind::String:
@@ -640,19 +648,21 @@ SerialPointer SerialReader::getPointer(SerialIndex index)
String string = getString(index);
return SerialPointer(string.getStringRepresentation());
}
- case SerialTypeKind::NodeBase:
- {
- return SerialPointer((NodeBase*)m_objects[Index(index)]);
- }
- case SerialTypeKind::RefObject:
+ case SerialTypeKind::ImportSymbol:
{
- return SerialPointer((RefObject*)m_objects[Index(index)]);
+ if (ptr.m_kind == SerialTypeKind::Unknown)
+ {
+ // TODO(JS):
+ // Could have an error here, because import symbol was not set
+ // For now just return nullptr
+ return SerialPointer();
+ }
+ break;
}
default: break;
}
- SLANG_ASSERT(!"Cannot access as a pointer");
- return SerialPointer();
+ return ptr;
}
String SerialReader::getString(SerialIndex index)
@@ -672,7 +682,7 @@ String SerialReader::getString(SerialIndex index)
return String();
}
- RefObject* obj = (RefObject*)m_objects[Index(index)];
+ RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>();
if (obj)
{
@@ -714,7 +724,7 @@ Name* SerialReader::getName(SerialIndex index)
return nullptr;
}
- RefObject* obj = (RefObject*)m_objects[Index(index)];
+ RefObject* obj = m_objects[Index(index)].dynamicCast<RefObject>();
if (obj)
{
@@ -728,7 +738,7 @@ Name* SerialReader::getName(SerialIndex index)
SLANG_ASSERT(stringRep);
// I don't need to scope, as scoped in NamePool
- name = m_namePool->getName(String(stringRep));
+ name = m_namePool->getName(String(stringRep));
// Store as name, as can always access the inner string if needed
m_objects[Index(index)] = name;
@@ -749,23 +759,25 @@ UnownedStringSlice SerialReader::getStringSlice(SerialIndex index)
const Entry* entry = m_entries[Index(index)];
// It has to be a string type
- if (entry->typeKind != SerialTypeKind::String)
+ if (entry->typeKind == SerialTypeKind::String ||
+ entry->typeKind == SerialTypeKind::ImportSymbol)
{
- SLANG_ASSERT(!"Not a string");
- return UnownedStringSlice();
- }
+ auto stringEntry = static_cast<const SerialInfo::StringEntry*>(entry);
- auto stringEntry = static_cast<const SerialInfo::StringEntry*>(entry);
+ const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars;
- const uint8_t* src = (const uint8_t*)stringEntry->sizeAndChars;
+ // Decode the string
+ uint32_t size;
+ int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size);
+ return UnownedStringSlice((const char*)src + sizeSize, size);
+ }
- // Decode the string
- uint32_t size;
- int sizeSize = ByteEncodeUtil::decodeLiteUInt32(src, &size);
- return UnownedStringSlice((const char*)src + sizeSize, size);
+ // Can't be accessed as a slice
+ SLANG_ASSERT(!"Not accessible as a slice");
+ return UnownedStringSlice();
}
-SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, List<const SerialInfo::Entry*>& outEntries)
+/* static */SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, SerialClasses* serialClasses, List<const Entry*>& outEntries)
{
// Check the input data is at least aligned to the max alignment (otherwise everything cannot be aligned correctly)
SLANG_ASSERT((size_t(data) & (SerialInfo::MAX_ALIGNMENT - 1)) == 0);
@@ -781,7 +793,7 @@ SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, Lis
const Entry* entry = (const Entry*)cur;
outEntries.add(entry);
- const size_t entrySize = entry->calcSize(m_classes);
+ const size_t entrySize = entry->calcSize(serialClasses);
cur += entrySize;
// Need to get the next alignment
@@ -794,10 +806,8 @@ SlangResult SerialReader::loadEntries(const uint8_t* data, size_t dataCount, Lis
return SLANG_OK;
}
-SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool)
+SlangResult SerialReader::constructObjects(NamePool* namePool)
{
- SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount, m_entries));
-
m_namePool = namePool;
m_objects.clearAndDeallocate();
@@ -811,6 +821,13 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool*
switch (entry->typeKind)
{
+ case SerialTypeKind::ImportSymbol:
+ {
+ // We don't construct any object for an imported symbol.
+ // It will be the responsibility of external code to interpet the symbols and *set* the appopriate
+ // objects prior to a call to `deserializeObjects`
+ break;
+ }
case SerialTypeKind::String:
{
// Don't need to construct an object. This is probably a StringRepresentation, or a Name
@@ -826,7 +843,7 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool*
{
return SLANG_FAIL;
}
- m_objects[i] = obj;
+ m_objects[i].set(entry->typeKind, obj);
break;
}
case SerialTypeKind::Array:
@@ -837,12 +854,18 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool*
}
}
+ return SLANG_OK;
+}
+
+SlangResult SerialReader::deserializeObjects()
+{
// Deserialize
for (Index i = 1; i < m_entries.getCount(); ++i)
{
const Entry* entry = m_entries[i];
- void* native = m_objects[i];
- if (!native)
+ // First see if there is anything to construct
+ SerialPointer& dstPtr = m_objects[i];
+ if (!dstPtr)
{
continue;
}
@@ -859,7 +882,7 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool*
}
const uint8_t* src = (const uint8_t*)(objectEntry + 1);
- uint8_t* dst = (uint8_t*)m_objects[i];
+ uint8_t* dst = (uint8_t*)dstPtr.m_ptr;
// It must be constructed
SLANG_ASSERT(dst);
@@ -886,4 +909,15 @@ SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool*
return SLANG_OK;
}
+
+SlangResult SerialReader::load(const uint8_t* data, size_t dataCount, NamePool* namePool)
+{
+ // Load and place entries into entries table
+ SLANG_RETURN_ON_FAIL(loadEntries(data, dataCount));
+ // Construct all of the objects
+ SLANG_RETURN_ON_FAIL(constructObjects(namePool));
+ SLANG_RETURN_ON_FAIL(deserializeObjects());
+ return SLANG_OK;
+}
+
} // namespace Slang
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index 0e7fdd68a..a48a7e216 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -46,6 +46,7 @@ enum class SerialTypeKind : uint8_t
String, ///< String
Array, ///< Array
+ ImportSymbol, ///< Holds the name of the import symbol. Represented in exactly the same way as a string
NodeBase, ///< NodeBase derived
RefObject, ///< RefObject derived types
@@ -160,6 +161,12 @@ struct SerialPointer
{
}
+ /// True if the ptr is set
+ SLANG_FORCE_INLINE operator bool() const { return m_ptr != nullptr; }
+
+ /// Directly set pointer/kind
+ void set(SerialTypeKind kind, void* ptr) { m_kind = kind; m_ptr = ptr; }
+
static SerialTypeKind getKind(const RefObject*) { return SerialTypeKind::RefObject; }
static SerialTypeKind getKind(const NodeBase*) { return SerialTypeKind::NodeBase; }
@@ -221,13 +228,23 @@ public:
Name* getName(SerialIndex index);
UnownedStringSlice getStringSlice(SerialIndex index);
- /// Load the entries table (without deserializing anything)
- /// NOTE! data must stay ins scope for outEntries to be valid
- SlangResult loadEntries(const uint8_t* data, size_t dataCount, List<const SerialInfo::Entry*>& outEntries);
+ SlangResult loadEntries(const uint8_t* data, size_t dataCount) { return loadEntries(data, dataCount, m_classes, m_entries); }
+ /// For each entry construct an object. Does *NOT* deserialize them
+ SlangResult constructObjects(NamePool* namePool);
+ /// Entries must be loaded (with loadEntries), and objects constructed (with constructObjects) before deserializing
+ SlangResult deserializeObjects();
/// NOTE! data must stay ins scope when reading takes place
SlangResult load(const uint8_t* data, size_t dataCount, NamePool* namePool);
+ /// Get the entries list
+ const List<const Entry*>& getEntries() const { return m_entries; }
+
+ /// Access the objects list
+ /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope
+ List<SerialPointer>& getObjects() { return m_objects; }
+ const List<SerialPointer>& getObjects() const { return m_objects; }
+
/// Add an object to be kept in scope
void addScope(const RefObject* obj) { m_scope.add(obj); }
@@ -242,9 +259,14 @@ public:
}
~SerialReader();
+ /// Load the entries table (without deserializing anything)
+ /// NOTE! data must stay ins scope for outEntries to be valid
+ static SlangResult loadEntries(const uint8_t* data, size_t dataCount, SerialClasses* serialClasses, List<const Entry*>& outEntries);
+
protected:
List<const Entry*> m_entries; ///< The entries
- List<void*> m_objects; ///< The constructed objects
+
+ List<SerialPointer> m_objects; ///< The constructed objects
NamePool* m_namePool; ///< Pool names are added to
List<const RefObject*> m_scope; ///< Keeping objects in scope
@@ -305,10 +327,15 @@ public:
template <typename T>
SerialIndex addArray(const T* in, Index count);
- SerialIndex addString(const UnownedStringSlice& slice);
+ /// Add the string
+ SerialIndex addString(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::String, m_sliceMap, slice); }
SerialIndex addString(const String& in);
SerialIndex addName(const Name* name);
-
+
+ /// Adding import symbols
+ SerialIndex addImportSymbol(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); }
+ SerialIndex addImportSymbol(const String& string){ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice()); }
+
/// Set a the ptr associated with an index.
/// NOTE! That there cannot be a pre-existing setting.
void setPointerIndex(const NodeBase* ptr, SerialIndex index);
@@ -331,6 +358,10 @@ public:
protected:
+ typedef Dictionary<UnownedStringSlice, Index> SliceMap;
+
+ SerialIndex _addStringSlice(SerialTypeKind typeKind, SliceMap& sliceMap, const UnownedStringSlice& slice);
+
SerialIndex _addArray(size_t elementSize, size_t alignment, const void* elements, Index elementCount);
SerialIndex _add(const void* nativePtr, SerialInfo::Entry* entry)
@@ -345,8 +376,10 @@ protected:
Dictionary<const void*, Index> m_ptrMap; // Maps a pointer to an entry index
+
// NOTE! Assumes the content stays in scope!
- Dictionary<UnownedStringSlice, Index> m_sliceMap;
+ SliceMap m_sliceMap;
+ SliceMap m_importSymbolMap;
SerialExtraObjects m_extraObjects; ///< Extra objects
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index ad4e82d4f..90ddf030c 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -141,9 +141,8 @@ void Session::init()
builtinSourceManager.initialize(nullptr, nullptr);
// Built in linkage uses the built in builder
- m_builtinLinkage = new Linkage(this, builtinAstBuilder);
+ m_builtinLinkage = new Linkage(this, builtinAstBuilder, nullptr);
-
// Because the `Session` retains the builtin `Linkage`,
// we need to make sure that the parent pointer inside
// `Linkage` doesn't create a retain cycle.
@@ -207,7 +206,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::createSession(
slang::ISession** outSession)
{
RefPtr<ASTBuilder> astBuilder(new ASTBuilder(m_sharedASTBuilder, "Session::astBuilder"));
- RefPtr<Linkage> linkage = new Linkage(this, astBuilder);
+ RefPtr<Linkage> linkage = new Linkage(this, astBuilder, getBuiltinLinkage());
Int targetCount = desc.targetCount;
for(Int ii = 0; ii < targetCount; ++ii)
@@ -457,7 +456,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target)
//
-Linkage::Linkage(Session* session, ASTBuilder* astBuilder)
+Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage)
: m_session(session)
, m_retainedSession(session)
, m_sourceManager(&m_defaultSourceManager)
@@ -468,6 +467,15 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder)
m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr);
setFileSystem(nullptr);
+
+ // Copy of the built in linkages modules
+ if (builtinLinkage)
+ {
+ for (const auto& pair : builtinLinkage->mapNameToLoadedModules)
+ {
+ mapNameToLoadedModules.Add(pair.Key, pair.Value);
+ }
+ }
}
ISlangUnknown* Linkage::getInterface(const Guid& guid)
@@ -1381,7 +1389,7 @@ EndToEndCompileRequest::EndToEndCompileRequest(
, m_sink(nullptr)
{
RefPtr<ASTBuilder> astBuilder(new ASTBuilder(session->m_sharedASTBuilder, "EndToEnd::Linkage::astBuilder"));
- m_linkage = new Linkage(session, astBuilder);
+ m_linkage = new Linkage(session, astBuilder, session->getBuiltinLinkage());
init();
}
@@ -1945,6 +1953,7 @@ void FilePathDependencyList::addDependency(Module* module)
Module::Module(Linkage* linkage)
: ComponentType(linkage)
, m_astBuilder(linkage->getASTBuilder()->getSharedASTBuilder(), "Module")
+ , m_mangledExportPool(StringSlicePool::Style::Empty)
{
addModuleDependency(this);
}
@@ -1995,6 +2004,87 @@ void Module::_addEntryPoint(EntryPoint* entryPoint)
m_entryPoints.add(entryPoint);
}
+static bool _canExportDeclSymbol(ASTNodeType type)
+{
+ switch (type)
+ {
+ case ASTNodeType::ModuleDecl:
+ case ASTNodeType::EmptyDecl:
+ case ASTNodeType::NamespaceDecl:
+ {
+ return false;
+ }
+ default: break;
+ }
+
+ return true;
+}
+
+static bool _canRecurseExportSymbol(Decl* decl)
+{
+ if (as<FunctionDeclBase>(decl) ||
+ as<ScopeDecl>(decl))
+ {
+ return false;
+ }
+ return true;
+}
+
+void Module::_processFindDeclsExportSymbolsRec(Decl* decl)
+{
+ if (_canExportDeclSymbol(decl->astNodeType))
+ {
+ // It's a reference to a declaration in another module, so first get the symbol name.
+ String mangledName = getMangledName(getASTBuilder(), decl);
+
+ Index index = Index(m_mangledExportPool.add(mangledName));
+
+ // TODO(JS): It appears that more than one entity might have the same mangled name.
+ // So for now we ignore and just take the first one.
+ if (index == m_mangledExportSymbols.getCount())
+ {
+ m_mangledExportSymbols.add(decl);
+ }
+ }
+
+ if (!_canRecurseExportSymbol(decl))
+ {
+ // We don't need to recurse any further into this
+ return;
+ }
+
+ // process `decl` itself
+ if(auto containerDecl = as<ContainerDecl>(decl))
+ {
+ for (auto child : containerDecl->members)
+ {
+ _processFindDeclsExportSymbolsRec(child);
+ }
+ }
+ else if (auto genericDecl = as<GenericDecl>(decl))
+ {
+ _processFindDeclsExportSymbolsRec(genericDecl->inner);
+ }
+}
+
+NodeBase* Module::findExportFromMangledName(const UnownedStringSlice& slice)
+{
+ // Will be non zero if has been previously attempted
+ if (m_mangledExportSymbols.getCount() == 0)
+ {
+ // Build up the exported mangled name list
+ _processFindDeclsExportSymbolsRec(getModuleDecl());
+
+ // If nothing found, mark that we have tried looking by making m_mangledExportSymbols.getCount() != 0
+ if (m_mangledExportSymbols.getCount() == 0)
+ {
+ m_mangledExportSymbols.add(nullptr);
+ }
+ }
+
+ const Index index = m_mangledExportPool.findIndex(slice);
+ return (index >= 0) ? m_mangledExportSymbols[index] : nullptr;
+}
// ComponentType
@@ -2742,6 +2832,7 @@ void Session::addBuiltinSource(
Name* moduleName = getNamePool()->getName(path);
auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName);
+
compileRequest->addTranslationUnitSourceString(
translationUnitIndex,
path,
@@ -2764,6 +2855,9 @@ void Session::addBuiltinSource(
auto module = compileRequest->translationUnits[translationUnitIndex]->getModule();
auto moduleDecl = module->getModuleDecl();
+ // Put in the loaded module map
+ linkage->mapNameToLoadedModules.Add(moduleName, module);
+
// Add the resulting code to the appropriate scope
if (!scope->containerDecl)
{
@@ -3189,6 +3283,8 @@ SlangResult _addLibraryReference(EndToEndCompileRequest* req, Stream* stream)
options.session = req->getSession();
options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder();
options.sourceManager = linkage->getSourceManager();
+ options.linkage = req->getLinkage();
+ options.sink = req->getSink();
SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, containerData));