summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-serialize.cpp
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2020-08-31 13:02:55 -0400
committerGitHub <noreply@github.com>2020-08-31 13:02:55 -0400
commit69025ad82238a7402b18d9c566fac1574faef684 (patch)
treeab01e4248071f9f597aa04f15742852a7662ed23 /source/slang/slang-ast-serialize.cpp
parentbaa789e0c9109bcb1e717ce4a9953709e7345e55 (diff)
AST Serialization in Modules (#1524)
* First pass at filter for AST serial writing. * Serialization of AST for modules. * Removed some commented out source. Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ast-serialize.cpp')
-rw-r--r--source/slang/slang-ast-serialize.cpp251
1 files changed, 233 insertions, 18 deletions
diff --git a/source/slang/slang-ast-serialize.cpp b/source/slang/slang-ast-serialize.cpp
index 6eaa29968..22d5c3de2 100644
--- a/source/slang/slang-ast-serialize.cpp
+++ b/source/slang/slang-ast-serialize.cpp
@@ -8,9 +8,12 @@
#include "slang-type-layout.h"
#include "slang-ast-dump.h"
+#include "slang-mangle.h"
#include "slang-ast-support-types.h"
+#include "slang-legalize-types.h"
+
#include "../core/slang-byte-encode-util.h"
namespace Slang {
@@ -46,6 +49,60 @@ static void _toNativeValue(ASTSerialReader* reader, const SERIAL_TYPE& src, NATI
ASTSerialTypeInfo<NATIVE_TYPE>::toNative(reader, &src, &dst);
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleASTSerialFilter !!!!!!!!!!!!!!!!!!!!!!!!
+
+
+ASTSerialIndex ModuleASTSerialFilter::writePointer(ASTSerialWriter* writer, const NodeBase* inPtr)
+{
+ NodeBase* ptr = const_cast<NodeBase*>(inPtr);
+ SLANG_ASSERT(ptr);
+
+ if (Decl* decl = as<Decl>(ptr))
+ {
+ ModuleDecl* moduleDecl = findModuleForDecl(decl);
+ SLANG_ASSERT(moduleDecl);
+ if (moduleDecl && moduleDecl != m_moduleDecl)
+ {
+ ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder();
+
+ // It's a reference to a declaration in another module, so create an ImportExternalDecl.
+
+ String mangledName = getMangledName(astBuilder, decl);
+
+ ImportExternalDecl* importDecl = astBuilder->create<ImportExternalDecl>();
+ importDecl->mangledName = mangledName;
+ const ASTSerialIndex index = writer->writePointer(importDecl);
+
+ // Set as the index of this
+ writer->setPointerIndex(ptr, index);
+
+ return index;
+ }
+ else
+ {
+ // Okay... we can just write it out then
+ return writer->writePointer(ptr);
+ }
+ }
+
+ // TODO(JS): What we really want to do here is to ignore bodies functions.
+ // It's not 100% clear if this is even right though - for example does type inference
+ // imply the body is needed to say infer a return type?
+ // Also not clear if statements in other scenarios (if there are others) might need to be kept.
+ //
+ // For now we just ignore all stmts
+
+ if (Stmt* stmt = as<Stmt>(ptr))
+ {
+ //
+ writer->setPointerIndex(stmt, ASTSerialIndex(0));
+ return ASTSerialIndex(0);
+ }
+
+ // For now for everything else just write it
+ return writer->writePointer(ptr);
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Serial <-> Native conversion !!!!!!!!!!!!!!!!!!!!!!!!
@@ -1030,28 +1087,20 @@ ASTSerialClasses::ASTSerialClasses():
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialWriter !!!!!!!!!!!!!!!!!!!!!!!!!!!!
-ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes) :
+ASTSerialWriter::ASTSerialWriter(ASTSerialClasses* classes, ASTSerialFilter* filter) :
m_arena(2048),
- m_classes(classes)
+ m_classes(classes),
+ m_filter(filter)
{
// 0 is always the null pointer
m_entries.add(nullptr);
m_ptrMap.Add(nullptr, 0);
}
-ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node)
+ASTSerialIndex ASTSerialWriter::writePointer(const NodeBase* node)
{
- // Null is always 0
- if (node == nullptr)
- {
- return ASTSerialIndex(0);
- }
- // Look up in the map
- Index* indexPtr = m_ptrMap.TryGetValue(node);
- if (indexPtr)
- {
- return ASTSerialIndex(*indexPtr);
- }
+ // This pointer cannot be in the map
+ SLANG_ASSERT(m_ptrMap.TryGetValue(node) == nullptr);
const ASTSerialClass* serialClass = m_classes->getSerialClass(node->astNodeType);
@@ -1072,6 +1121,7 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node)
for (Index i = 0; i < serialClass->fieldsCount; ++i)
{
auto field = serialClass->fields[i];
+
// Work out the offsets
auto srcField = ((const uint8_t*)node) + field.nativeOffset;
auto dstField = serialPayload + field.serialOffset;
@@ -1088,6 +1138,35 @@ ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node)
return index;
}
+void ASTSerialWriter::setPointerIndex(const NodeBase* ptr, ASTSerialIndex index)
+{
+ m_ptrMap.Add(ptr, Index(index));
+}
+
+ASTSerialIndex ASTSerialWriter::addPointer(const NodeBase* node)
+{
+ // Null is always 0
+ if (node == nullptr)
+ {
+ return ASTSerialIndex(0);
+ }
+ // Look up in the map
+ Index* indexPtr = m_ptrMap.TryGetValue(node);
+ if (indexPtr)
+ {
+ return ASTSerialIndex(*indexPtr);
+ }
+
+ if (m_filter)
+ {
+ return m_filter->writePointer(this, node);
+ }
+ else
+ {
+ return writePointer(node);
+ }
+}
+
ASTSerialIndex ASTSerialWriter::addPointer(const RefObject* obj)
{
// Null is always 0
@@ -1246,6 +1325,8 @@ ASTSerialIndex ASTSerialWriter::_addArray(size_t elementSize, size_t alignment,
return ASTSerialIndex(m_entries.getCount() - 1);
}
+static const uint8_t s_fixBuffer[ASTSerialInfo::MAX_ALIGNMENT]{ 0, };
+
SlangResult ASTSerialWriter::write(Stream* stream)
{
const Int entriesCount = m_entries.getCount();
@@ -1263,8 +1344,6 @@ SlangResult ASTSerialWriter::write(Stream* stream)
// knowing that removeLast cannot release memory, means the sentinal must be at the last position.
entries[entriesCount] = &sentinal;
-
- static const uint8_t fixBuffer[ASTSerialInfo::MAX_ALIGNMENT] { 0, };
{
size_t offset = 0;
@@ -1303,7 +1382,7 @@ SlangResult ASTSerialWriter::write(Stream* stream)
// If we needed to fix so that subsequent alignment is right, write out extra bytes here
if (alignmentFixSize)
{
- stream->write(fixBuffer, alignmentFixSize);
+ stream->write(s_fixBuffer, alignmentFixSize);
}
}
catch (const IOException&)
@@ -1320,6 +1399,80 @@ SlangResult ASTSerialWriter::write(Stream* stream)
return SLANG_OK;
}
+SlangResult ASTSerialWriter::writeIntoContainer(RiffContainer* container)
+{
+ typedef RiffContainer::Chunk Chunk;
+ typedef RiffContainer::ScopeChunk ScopeChunk;
+
+ // This is the container for the AST Data
+ ScopeChunk scopeModule(container, Chunk::Kind::List, ASTSerialBinary::kSlangASTModuleFourCC);
+ {
+ ScopeChunk scopeData(container, Chunk::Kind::Data, ASTSerialBinary::kSlangASTModuleDataFourCC);
+
+ {
+ // Sentinal so we don't need special handling for end of list
+ ASTSerialInfo::Entry sentinal;
+ sentinal.type = ASTSerialInfo::Type::String;
+ sentinal.info = ASTSerialInfo::EntryInfo::Alignment1;
+
+ size_t offset = 0;
+ const Int entriesCount = m_entries.getCount();
+
+ {
+ m_entries.add(&sentinal);
+ m_entries.removeLast();
+ // Note strictly required in our impl of List. But by writing this and
+ // knowing that removeLast cannot release memory, means the sentinal must be at the last position.
+ m_entries.getBuffer()[entriesCount] = &sentinal;
+ }
+
+ ASTSerialInfo::Entry*const* entries = m_entries.getBuffer();
+
+ ASTSerialInfo::Entry* entry = entries[1];
+ // We start on 1, because 0 is nullptr and not used for anything
+ for (Index i = 1; i < entriesCount; ++i)
+ {
+ ASTSerialInfo::Entry* next = entries[i + 1];
+
+ // Before writing we need to store the next alignment
+
+ const size_t nextAlignment = ASTSerialInfo::getAlignment(next->info);
+ const size_t alignment = ASTSerialInfo::getAlignment(entry->info);
+
+ entry->info = ASTSerialInfo::combineWithNext(entry->info, next->info);
+
+ // Check we are aligned correctly
+ SLANG_ASSERT((offset & (alignment - 1)) == 0);
+
+ // When we write, we need to make sure it take into account the next alignment
+ const size_t entrySize = entry->calcSize(m_classes);
+
+ // Work out the fix for next alignment
+ size_t nextOffset = offset + entrySize;
+ nextOffset = (nextOffset + nextAlignment - 1) & ~(nextAlignment - 1);
+
+ size_t alignmentFixSize = nextOffset - (offset + entrySize);
+
+ // The fix must be less than max alignment. We require it to be less because we aligned each Entry to
+ // MAX_ALIGNMENT, and so < MAX_ALIGNMENT is the most extra bytes we can write
+ SLANG_ASSERT(alignmentFixSize < ASTSerialInfo::MAX_ALIGNMENT);
+
+ container->write(entry, entrySize);
+ if (alignmentFixSize)
+ {
+ container->write(s_fixBuffer, alignmentFixSize);
+ }
+
+ // Onto next
+ offset = nextOffset;
+ entry = next;
+ }
+ }
+ }
+
+ return SLANG_OK;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialInfo::Entry !!!!!!!!!!!!!!!!!!!!!!!!
size_t ASTSerialInfo::Entry::calcSize(ASTSerialClasses* serialClasses) const
@@ -1713,6 +1866,62 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil
}
+/* static */Result ASTSerialReader::readContainerModules(RiffContainer* container, Linkage* linkage, List<RefPtr<Module>>& outModules)
+{
+ List<RiffContainer::ListChunk*> moduleChunks;
+ // First try to find a list
+ {
+ RiffContainer::ListChunk* listChunk = container->getRoot()->findListRec(SerialBinary::kSlangModuleListFourCc);
+ if (listChunk)
+ {
+ listChunk->findContained(ASTSerialBinary::kSlangASTModuleFourCC, moduleChunks);
+ }
+ else
+ {
+ // Maybe its just a single module
+ RiffContainer::ListChunk* moduleChunk = container->getRoot()->findListRec(ASTSerialBinary::kSlangASTModuleFourCC);
+ if (!moduleChunk)
+ {
+ // Couldn't find any modules
+ return SLANG_FAIL;
+ }
+ moduleChunks.add(moduleChunk);
+ }
+ }
+
+ RefPtr<ASTSerialClasses> serialClasses(new ASTSerialClasses);
+
+ // Okay, deserialize the each of the module chunks
+ for (RiffContainer::ListChunk* listChunk : moduleChunks)
+ {
+ // Look for the module data
+ auto data = listChunk->findContainedData(ASTSerialBinary::kSlangASTModuleDataFourCC);
+
+ if (!data)
+ {
+ return SLANG_FAIL;
+ }
+
+ ASTSerialReader reader(serialClasses);
+
+ RefPtr<Module> module(new Module(linkage));
+ SLANG_RETURN_ON_FAIL(reader.load((uint8_t*)data->getPayload(), data->getSize(), module->getASTBuilder(), linkage->getNamePool()));
+
+ ModuleDecl* moduleDecl = reader.getPointer(ASTSerialIndex(1)).dynamicCast<ModuleDecl>();
+ if (!moduleDecl)
+ {
+ return SLANG_FAIL;
+ }
+
+ // Set on the module
+ module->setModuleDecl(moduleDecl);
+
+ outModules.add(module);
+ }
+
+ return SLANG_OK;
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerializeUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!!
/* static */SlangResult ASTSerialTestUtil::selfTest()
@@ -1762,7 +1971,12 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil
{
OwnedMemoryStream stream(FileAccess::ReadWrite);
- ASTSerialWriter writer(classes);
+ ModuleDecl* moduleDecl = as<ModuleDecl>(node);
+ ModuleASTSerialFilter filterStorage(moduleDecl);
+
+ ASTSerialFilter* filter = moduleDecl ? &filterStorage : nullptr;
+
+ ASTSerialWriter writer(classes, filter);
// Lets serialize it all
writer.addPointer(node);
@@ -1831,6 +2045,7 @@ SlangResult ASTSerialReader::load(const uint8_t* data, size_t dataCount, ASTBuil
File::writeAllText("ast-read.ast-dump", readDump);
File::writeAllText("ast-orig.ast-dump", origDump);
+
if (readDump != origDump)
{
return SLANG_FAIL;