diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2020-08-31 13:02:55 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-08-31 13:02:55 -0400 |
| commit | 69025ad82238a7402b18d9c566fac1574faef684 (patch) | |
| tree | ab01e4248071f9f597aa04f15742852a7662ed23 /source/slang/slang-ast-serialize.cpp | |
| parent | baa789e0c9109bcb1e717ce4a9953709e7345e55 (diff) | |
AST Serialization in Modules (#1524)
* First pass at filter for AST serial writing.
* Serialization of AST for modules.
* Removed some commented out source.
Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ast-serialize.cpp')
| -rw-r--r-- | source/slang/slang-ast-serialize.cpp | 251 |
1 files changed, 233 insertions, 18 deletions
diff --git a/source/slang/slang-ast-serialize.cpp b/source/slang/slang-ast-serialize.cpp index 6eaa29968..22d5c3de2 100644 --- a/source/slang/slang-ast-serialize.cpp +++ b/source/slang/slang-ast-serialize.cpp @@ -8,9 +8,12 @@ #include "slang-type-layout.h" #include "slang-ast-dump.h" +#include "slang-mangle.h" #include "slang-ast-support-types.h" +#include "slang-legalize-types.h" + #include "../core/slang-byte-encode-util.h" namespace Slang { @@ -46,6 +49,60 @@ static void _toNativeValue(ASTSerialReader* reader, const SERIAL_TYPE& src, NATI ASTSerialTypeInfo<NATIVE_TYPE>::toNative(reader, &src, &dst); } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleASTSerialFilter !!!!!!!!!!!!!!!!!!!!!!!! + + +ASTSerialIndex ModuleASTSerialFilter::writePointer(ASTSerialWriter* writer, const NodeBase* inPtr) +{ + NodeBase* ptr = const_cast<NodeBase*>(inPtr); + SLANG_ASSERT(ptr); + + if (Decl* decl = as<Decl>(ptr)) + { + ModuleDecl* moduleDecl = findModuleForDecl(decl); + SLANG_ASSERT(moduleDecl); + if (moduleDecl && moduleDecl != m_moduleDecl) + { + ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); + + // It's a reference to a declaration in another module, so create an ImportExternalDecl. + + String mangledName = getMangledName(astBuilder, decl); + + ImportExternalDecl* importDecl = astBuilder->create<ImportExternalDecl>(); + importDecl->mangledName = mangledName; + const ASTSerialIndex index = writer->writePointer(importDecl); + + // Set as the index of this + writer->setPointerIndex(ptr, index); + + return index; + } + else + { + // Okay... we can just write it out then + return writer->writePointer(ptr); + } + } + + // TODO(JS): What we really want to do here is to ignore bodies functions. + // It's not 100% clear if this is even right though - for example does type inference + // imply the body is needed to say infer a return type? + // Also not clear if statements in other scenarios (if there are others) might need to be kept. + // + // For now we just ignore all stmts + + if (Stmt* stmt = as<Stmt>(ptr)) + { + // + writer->setPointerIndex(stmt, ASTSerialIndex(0)); + return ASTSerialIndex(0); + } + + // For now for everything else just write it + return writer->writePointer(ptr); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Serial <-> Native conversion !!!!!!!!!!!!!!!!!!!!!!!! @@ -1030,28 +1087,20 @@ ASTSerialClasses::ASTSerialClasses(): // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!! -ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes) : +ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes, ASTSerialFilter* filter) : m_arena(2048), - m_classes(classes) + m_classes(classes), + m_filter(filter) { // 0 is always the null pointer m_entries.add(nullptr); m_ptrMap.Add(nullptr, 0); } -ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) +ASTSerialIndex ASTSerialWriter::writePointer(const NodeBase* node) { - // Null is always 0 - if (node == nullptr) - { - return ASTSerialIndex(0); - } - // Look up in the map - Index* indexPtr = m_ptrMap.TryGetValue(node); - if (indexPtr) - { - return ASTSerialIndex(*indexPtr); - } + // This pointer cannot be in the map + SLANG_ASSERT(m_ptrMap.TryGetValue(node) == nullptr); const ASTSerialClass* serialClass = m_classes->getSerialClass(node->astNodeType); @@ -1072,6 +1121,7 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) for (Index i = 0; i < serialClass->fieldsCount; ++i) { auto field = serialClass->fields[i]; + // Work out the offsets auto srcField = ((const uint8_t*)node) + field.nativeOffset; auto dstField = serialPayload + field.serialOffset; @@ -1088,6 +1138,35 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) return index; } +void ASTSerialWriter::setPointerIndex(const NodeBase* ptr, ASTSerialIndex index) +{ + m_ptrMap.Add(ptr, Index(index)); +} + +ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node) +{ + // Null is always 0 + if (node == nullptr) + { + return ASTSerialIndex(0); + } + // Look up in the map + Index* indexPtr = m_ptrMap.TryGetValue(node); + if (indexPtr) + { + return ASTSerialIndex(*indexPtr); + } + + if (m_filter) + { + return m_filter->writePointer(this, node); + } + else + { + return writePointer(node); + } +} + ASTSerialIndex ASTSerialWriter::addPointer(const RefObject* obj) { // Null is always 0 @@ -1246,6 +1325,8 @@ ASTSerialIndex ASTSerialWriter::_addArray(size_t elementSize, size_t alignment, return ASTSerialIndex(m_entries.getCount() - 1); } +static const uint8_t s_fixBuffer[ASTSerialInfo::MAX_ALIGNMENT]{ 0, }; + SlangResult ASTSerialWriter::write(Stream* stream) { const Int entriesCount = m_entries.getCount(); @@ -1263,8 +1344,6 @@ SlangResult ASTSerialWriter::write(Stream* stream) // knowing that removeLast cannot release memory, means the sentinal must be at the last position. entries[entriesCount] = &sentinal; - - static const uint8_t fixBuffer[ASTSerialInfo::MAX_ALIGNMENT] { 0, }; { size_t offset = 0; @@ -1303,7 +1382,7 @@ SlangResult ASTSerialWriter::write(Stream* stream) // If we needed to fix so that subsequent alignment is right, write out extra bytes here if (alignmentFixSize) { - stream->write(fixBuffer, alignmentFixSize); + stream->write(s_fixBuffer, alignmentFixSize); } } catch (const IOException&) @@ -1320,6 +1399,80 @@ SlangResult ASTSerialWriter::write(Stream* stream) return SLANG_OK; } +SlangResult ASTSerialWriter::writeIntoContainer(RiffContainer* container) +{ + typedef RiffContainer::Chunk Chunk; + typedef RiffContainer::ScopeChunk ScopeChunk; + + // This is the container for the AST Data + ScopeChunk scopeModule(container, Chunk::Kind::List, ASTSerialBinary::kSlangASTModuleFourCC); + { + ScopeChunk scopeData(container, Chunk::Kind::Data, ASTSerialBinary::kSlangASTModuleDataFourCC); + + { + // Sentinal so we don't need special handling for end of list + ASTSerialInfo::Entry sentinal; + sentinal.type = ASTSerialInfo::Type::String; + sentinal.info = ASTSerialInfo::EntryInfo::Alignment1; + + size_t offset = 0; + const Int entriesCount = m_entries.getCount(); + + { + m_entries.add(&sentinal); + m_entries.removeLast(); + // Note strictly required in our impl of List. But by writing this and + // knowing that removeLast cannot release memory, means the sentinal must be at the last position. + m_entries.getBuffer()[entriesCount] = &sentinal; + } + + ASTSerialInfo::Entry*const* entries = m_entries.getBuffer(); + + ASTSerialInfo::Entry* entry = entries[1]; + // We start on 1, because 0 is nullptr and not used for anything + for (Index i = 1; i < entriesCount; ++i) + { + ASTSerialInfo::Entry* next = entries[i + 1]; + + // Before writing we need to store the next alignment + + const size_t nextAlignment = ASTSerialInfo::getAlignment(next->info); + const size_t alignment = ASTSerialInfo::getAlignment(entry->info); + + entry->info = ASTSerialInfo::combineWithNext(entry->info, next->info); + + // Check we are aligned correctly + SLANG_ASSERT((offset & (alignment - 1)) == 0); + + // When we write, we need to make sure it take into account the next alignment + const size_t entrySize = entry->calcSize(m_classes); + + // Work out the fix for next alignment + size_t nextOffset = offset + entrySize; + nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1); + + size_t alignmentFixSize = nextOffset - (offset + entrySize); + + // The fix must be less than max alignment. We require it to be less because we aligned each Entry to + // MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra bytes we can write + SLANG_ASSERT(alignmentFixSize < ASTSerialInfo::MAX_ALIGNMENT); + + container->write(entry, entrySize); + if (alignmentFixSize) + { + container->write(s_fixBuffer, alignmentFixSize); + } + + // Onto next + offset = nextOffset; + entry = next; + } + } + } + + return SLANG_OK; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialInfo::Entry !!!!!!!!!!!!!!!!!!!!!!!! size_t ASTSerialInfo::Entry::calcSize(ASTSerialClasses* serialClasses) const @@ -1713,6 +1866,62 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil } +/* static */Result ASTSerialReader::readContainerModules(RiffContainer* container, Linkage* linkage, List<RefPtr<Module>>& outModules) +{ + List<RiffContainer::ListChunk*> moduleChunks; + // First try to find a list + { + RiffContainer::ListChunk* listChunk = container->getRoot()->findListRec(SerialBinary::kSlangModuleListFourCc); + if (listChunk) + { + listChunk->findContained(ASTSerialBinary::kSlangASTModuleFourCC, moduleChunks); + } + else + { + // Maybe its just a single module + RiffContainer::ListChunk* moduleChunk = container->getRoot()->findListRec(ASTSerialBinary::kSlangASTModuleFourCC); + if (!moduleChunk) + { + // Couldn't find any modules + return SLANG_FAIL; + } + moduleChunks.add(moduleChunk); + } + } + + RefPtr<ASTSerialClasses> serialClasses(new ASTSerialClasses); + + // Okay, deserialize the each of the module chunks + for (RiffContainer::ListChunk* listChunk : moduleChunks) + { + // Look for the module data + auto data = listChunk->findContainedData(ASTSerialBinary::kSlangASTModuleDataFourCC); + + if (!data) + { + return SLANG_FAIL; + } + + ASTSerialReader reader(serialClasses); + + RefPtr<Module> module(new Module(linkage)); + SLANG_RETURN_ON_FAIL(reader.load((uint8_t*)data->getPayload(), data->getSize(), module->getASTBuilder(), linkage->getNamePool())); + + ModuleDecl* moduleDecl = reader.getPointer(ASTSerialIndex(1)).dynamicCast<ModuleDecl>(); + if (!moduleDecl) + { + return SLANG_FAIL; + } + + // Set on the module + module->setModuleDecl(moduleDecl); + + outModules.add(module); + } + + return SLANG_OK; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerializeUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!! /* static */SlangResult ASTSerialTestUtil::selfTest() @@ -1762,7 +1971,12 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil { OwnedMemoryStream stream(FileAccess::ReadWrite); - ASTSerialWriter writer(classes); + ModuleDecl* moduleDecl = as<ModuleDecl>(node); + ModuleASTSerialFilter filterStorage(moduleDecl); + + ASTSerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + + ASTSerialWriter writer(classes, filter); // Lets serialize it all writer.addPointer(node); @@ -1831,6 +2045,7 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil File::writeAllText("ast-read.ast-dump", readDump); File::writeAllText("ast-orig.ast-dump", origDump); + if (readDump != origDump) { return SLANG_FAIL; |
