From 1cf3f18a9ca1905a5bc51790ca723815dd5b1400 Mon Sep 17 00:00:00 2001 From: Theresa Foley <10618364+tangent-vector@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:26:57 -0700 Subject: A new approach to AST serialization (#6854) * A new approach to AST serialization This change completely overhauls the way that AST nodes are being serialized, and the offline source-code generation steps that enable that serialization. In practice, this ends up being a complete overhaul of the way that *modules* are being serialized (not just the AST part), although things like the serialization format for the Slang IR and for source locations are not affected. The rest of this commit message is broken down in to sections, in an attempt to help guide anybody looking at the code in how to make sense of all the changes. The Old C++ Extractor --------------------- AST serialization used to be driven by information scraped using the `slang-cpp-extractor` tool, which did an ad hoc parse of the C++ declarations of the AST node types and then generated a set of "X macros" that could be for macro-based code generation within the rest of the compiler. While the existing approach was functional, it wasn't easy to understand or maintain, and it has been getting in the way of forward progress on other features we'd like to work on in the language and compiler. This change removes the `slang-cpp-extractor` tool entirely. Marking Up the AST Declarations ------------------------------- The most notable change that contributors to the compiler may notice is the large number of invocations of a macro `FIDDLE()` on the declarations of the AST node types. The basic idea is that only declarations (namespaces, types, fields) that are preceded by `FIDDLE()` are visible to the code generator tool. So if somebody is working with the AST and wondering why a new node type isn't working, or why a field they added isn't being serialized correctly, it is probably because they need to add `FIDDLE()` in front of it. Generating the Boilerplate Code ------------------------------- The file `slang-ast-boilerplate.cpp` provides a good example of how the information extracted from the marked-up AST declarations gets used. In that file, the `FIDDLE TEMPLATE` construct is used to generate type information for each of the AST node types. Similar logic is used in `slang-ast-forward-declarations.h` to generate the declaration of the `ASTNodeType` enumeration, and forward-declare all the AST node classes. For many parts of the code, simply including that file replaces the need for the old `slang-generated-*.h` files. Replacing Visitors and Related Logic ------------------------------------ The old visitor types for the AST used the macros that were generated by `slang-cpp-extractor`, so something new was needed to replace them. The same goes for the `SLANG_AST_NODE_VIRTUAL_CALL` macros. The core of the solution implemented here is in `slang-ast-dispatch.h`. Given a "dispatchable" AST node type (say, `Expr`), a call like: ``` ASTNodeDispatcher(expr, [&](auto e) { return doSomething(e); }) ``` is an expression of type `R`, which does the equivalent of something like: ``` switch(expr->getTag()) { case ASTNodeType::VarExpr: return doSomething(static_cast(expr)); // ... } ``` The `SLANG_AST_NODE_VIRTUAL_CALL` macro is now implemented in terms of `ASTNodeDispatcher`. The implementation of the visitor types is more involved. The code in this change retains some of the macro names from the original version, just to try and make the parallels more clear. The visitor types are all implemented on top of the `ASTNodeDispatcher` approach, and use `FIDDLE TEMPLATE` to generate all the boilerplate `visit*()` method declarations. Refactoring of `Linkage` Module Loading --------------------------------------- Needing to revisit all the places where modules get deserialized made it clear that there is a lot of complexity and apparent duplication in the core routines on the `Linkage` that get used for loading modules. This change tries to clean up some of that logic, but it is worth noting that there are two legacy features that get in the way of making things as clean as they should be: * The `LoadedModuleDictionary` type that gets passed around a lot exists entirely to handle the corner case where somebody uses the Slang API to perform a compilation with multiple `TranslationUnitRequest`s in the same `FrontEndCompileRequest`, and one of the translation units `import`s the module defined by another of the translation units. * There are a lot of special-case behaviors and routines entirely there to support the `ModuleLibrary` feature, although that feature should be considered deprecated (or at least subject to getting entirely re-designed down the line). The basic idea of the cleanup is that all of the (non-deprecated) ways load a module from a serialized binary, or compile one from source should now bottleneck through `loadModuleImpl`, which then bifurcates into `loadSourceModuleImpl` for the compilation case and `loadBinaryModuleImpl` for the deserialization case. High-Level Serialization Approach --------------------------------- The old serialization logic used the [RIFF](https://en.wikipedia.org/wiki/Resource_Interchange_File_Format) format to encode the high-level structure of things, and this change retains that usage (and actually doubles down on the RIFF usage). The old serialization system relied on the idea that for any given type `Foo` that wants to support serialization, there should be something like a `SerialFooData` type in C++, that can represent the state of a `Foo`, and then the actual serialization applied to that `SerialFooData`. This means that in most cases there are four pieces of code written: * During serialization: * Copying the data of a `Foo` in memory over to a `SerialFooData` in memory * Writing the state of a `SerialFooData` into the serialized data stream * During deserialization: * Reading the state of a `SerialFooData` from a serialized data stream * Copying the data of the `SerialFooData` in memory over to a `Foo` The new logic gets rid of the intermediate `SerialFooData`. In the serialization direction, we take a `Foo` and write it to the `RIFFContainer` directly, or using some other utilities layered on top of it. In the deserialization direction, we have additional flexibility. Given a `RIFFContainer::Chunk*` that represents a serialized `Foo`, we often navigate through the in-memory representation of the RIFF data to get to the parts of the serialized value that we actually want/need, without needing to deserialize the entire `Foo`. To support this kind of operation, this change introduces a few helper types like `ContainerChunkRef` an `ModuleChunkRef`, that are little more than typed wrappers around a `RIFFContainer::Chunk*`. The Module "Container" Part --------------------------- A serialized `Module` is encoded as a RIFF chunk, using logic in `slang-serialize-container.cpp` - both before and after this change. This change reorganizes a lot of the code in that file, to account for the way that eliminating the intermediate `SerialContainerData` type streamlines the overall task of writing out the parts of the module. In the deserialization logic... there isn't really much to do in `slang-serialize-container.cpp`. Most of the logic in `slang.cpp` and `slang-module-library.cpp` that pertains to deserializing modules uses the `ModuleChunkRef`-based approach, and simply extracts the pieces of the serialized module that it needs. The Actual Serialization of the AST ----------------------------------- The actual AST serialization logic is in `slang-serialize-ast.cpp`. The basic approach in both the writing and reading directions is: * Use the `FIDDLE TEMPLATE` system to generate a set of functions, one for each AST node type, that recursively invoke the read/write logic on each field of that node (after recursively invoking the case for its direct superclass) * Use the `ASTNodeDispatcher` system to dispatch out to those functions whene reading or writing anything derived from `NodeBase` * For now, handle all types *not* derived from `NodeBase` by hand. There's a lot of room for improvement around that last item: it should be just as easy to generate the serialization and deserialization logic for other types that don't inherit from `NodeBase`, but the current change tries to err on the side of making the logic as explicit and simplistic as possible, rather than trying to get too clever too soon. The actual serialization *format* used for the AST is almost comically simplistic: the code uses hierarchical RIFF chunks to emulate a JSON-like structure. This is a very wasteful representation (e.g., a `bool` or a null pointer each take up *8 bytes*), but the goal for now is to start with the simplest thing that could possibly work, and only add more cleverness once we are sure it won't get in the way of important future improvements (like lazy/on-demand deserialization or IR and AST, to improve compiler startup times). The files `slang-serialize.{h,cpp}` have been co-opted to define a new pair of types `Encoder` and `Decoder` that are used for a more-or-less stream-oriented way or reading or writing RIFF chunks for the JSON-like structure. Almost everything related to the actual AST serialization could do with a cleanup pass, and some time spent on picking good/better names for everything. Smaller Stuff ------------- * Cleaned up a lot of code that was using bare `ASTNodeType` or the extractor's `ReflectClassInfo` type to consistently use `SyntaxClass`. * Fixed an apparent bug in how the destination-driven code genarator was handling `TryExpr`s * Fixed an apparent bug in how the GLSL legalization pass was handling translation of certain `SV_*` semantics. * format code * fixup: template errors caught by non-VS compilers * format code * fixup: more template errors * fixup: more stuff VS didn't catch * fixup: it's amazing VS doesn't catch these... * fixup: yet more template stuff VS ignores * fixup: more VS template nonsense * fixup: unreachable return macro usage * fixup: more unreacable returns * fixup: unused parameter * fixup: strict aliasing * fixup: allow missing entry point list chunk * fixup: wasm build script * fixup: AST changes since this PR was created --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Yong He --- source/slang/slang-serialize-ast.cpp | 1612 +++++++++++++++++++++++++++++++--- 1 file changed, 1473 insertions(+), 139 deletions(-) (limited to 'source/slang/slang-serialize-ast.cpp') diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index a7837edea..aad3bcc57 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -1,208 +1,1542 @@ // slang-serialize-ast.cpp #include "slang-serialize-ast.h" -#include "slang-ast-dump.h" -#include "slang-ast-support-types.h" -#include "slang-generated-ast-macro.h" -#include "slang-generated-ast.h" -#include "slang-serialize-ast-type-info.h" -#include "slang-serialize-factory.h" +#include "slang-ast-dispatch.h" +#include "slang-compiler.h" +#include "slang-diagnostics.h" +#include "slang-mangle.h" namespace Slang { +// TODO(tfoley): have the parser export this, or a utility function +// for initializing a `SyntaxDecl` in the common case. +// +NodeBase* parseSimpleSyntax(Parser* parser, void* userData); -// !!!!!!!!!!!!!!!!!!!!!! Generate fields for a type !!!!!!!!!!!!!!!!!!!!!!!!!!! -static const SerialClass* _addClass( - SerialClasses* serialClasses, - ASTNodeType type, - ASTNodeType super, - const List& fields) +struct ASTEncodingContext { - const SerialClass* superClass = - serialClasses->getSerialClass(SerialTypeKind::NodeBase, SerialSubType(super)); - return serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(type), - fields.getBuffer(), - fields.getCount(), - superClass); -} +private: + Encoder* encoder; + struct UnhandledCase + { + }; + + typedef Int DeclID; + Dictionary mapDeclToID; + List decls; + + struct ImportedDeclInfo + { + Int moduleIndex = -1; + Decl* decl; + }; + List importedDecls; + + typedef Int ValID; + Dictionary mapValToID; + List vals; -#define SLANG_AST_ADD_SERIAL_FIELD(FIELD_NAME, TYPE, param) \ - fields.add(SerialField::make(#FIELD_NAME, &obj->FIELD_NAME)); + ModuleDecl* _module = nullptr; -// Note that the obj point is not nullptr, because some compilers notice this is 'indexing from -// null' and warn/error. So we offset from 1. -#define SLANG_AST_ADD_SERIAL_CLASS(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) \ - { \ - NAME* obj = SerialField::getPtr(); \ - SLANG_UNUSED(obj); \ - fields.clear(); \ - SLANG_FIELDS_ASTNode_##NAME(SLANG_AST_ADD_SERIAL_FIELD, param) \ - _addClass(serialClasses, ASTNodeType::NAME, ASTNodeType::SUPER, fields); \ + SerialSourceLocWriter* _sourceLocWriter = nullptr; + +public: + ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) + : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) + { } -struct ASTFieldAccess -{ - static void calcClasses(SerialClasses* serialClasses) + template + void encodeASTNodeContent(T* node) { - // Add NodeBase first, and specially handle so that we add a null super class - serialClasses->add( - SerialTypeKind::NodeBase, - SerialSubType(ASTNodeType::NodeBase), - nullptr, - 0, - nullptr); + Encoder::WithObject withObject(encoder); - // Add the rest in order such that Super class is always added before its children - List fields; - SLANG_CHILDREN_ASTNode_NodeBase(SLANG_AST_ADD_SERIAL_CLASS, _) + ASTNodeDispatcher::dispatch(node, [&](auto n) { _encodeDataOf(n); }); } -}; -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTSerialUtil !!!!!!!!!!!!!!!!!!!!!!!!!!!! + void flush() + { + auto containerChunk = encoder->getRIFFChunk(); -/* static */ void ASTSerialUtil::addSerialClasses(SerialClasses* serialClasses) -{ - ASTFieldAccess::calcClasses(serialClasses); -} + RiffContainer::Chunk* declChunk = nullptr; + RiffContainer::Chunk* importedDeclChunk = nullptr; + RiffContainer::Chunk* valChunk = nullptr; + { + Encoder::WithArray withList(encoder); + declChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + importedDeclChunk = encoder->getRIFFChunk(); + } + { + Encoder::WithArray withList(encoder); + valChunk = encoder->getRIFFChunk(); + } + Int declIndex = 0; + Int importedDeclIndex = 0; + Int valIndex = 0; -/* static */ SlangResult ASTSerialUtil::testSerialize( - NodeBase* node, - RootNamePool* rootNamePool, - SharedASTBuilder* sharedASTBuilder, - SourceManager* sourceManager) -{ - RefPtr classes; + bool done = false; + do + { + done = true; + while (declIndex < decls.getCount()) + { + done = false; + encoder->setRIFFChunk(declChunk); + encodeASTNodeContent(decls[declIndex++]); + } + while (importedDeclIndex < importedDecls.getCount()) + { + done = false; + encoder->setRIFFChunk(importedDeclChunk); + encodeImportedDecl(importedDecls[importedDeclIndex++]); + } + while (valIndex < vals.getCount()) + { + done = false; + encoder->setRIFFChunk(valChunk); + encodeASTNodeContent(vals[valIndex++]); + } + } while (!done); + + RiffContainer::calcAndSetSize(containerChunk); + encoder->setRIFFChunk(containerChunk); + } - SerialClassesUtil::create(classes); + ModuleDecl* findModuleForDecl(Decl* decl) + { + for (auto d = decl; d; d = d->parentDecl) + { + if (auto m = as(d)) + return m; + } + return nullptr; + } - List contents; + ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) + { + auto declModule = findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; + } + DeclID getDeclID(Decl* decl) { - OwnedMemoryStream stream(FileAccess::ReadWrite); + SLANG_ASSERT(decl != nullptr); + + if (auto found = mapDeclToID.tryGetValue(decl)) + return *found; + + // We need to detect whether the declaration is an + // imported one, or one from this module itself. + // + // Imported declarations need to be handled very + // differently, since they'll involve resolving + // references to those other modules, and the + // declarations within them. + // + if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) + { + DeclID importedFromModuleDeclID = 0; + if (decl != importedFromModule) + { + importedFromModuleDeclID = getDeclID(importedFromModule); + } - ModuleDecl* moduleDecl = as(node); - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + DeclID id = ~importedDecls.getCount(); + mapDeclToID.add(decl, id); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + ImportedDeclInfo info; + info.moduleIndex = ~importedFromModuleDeclID; + info.decl = decl; + importedDecls.add(info); - SerialWriter writer(classes, filter); + return id; + } + else + { + DeclID id = decls.getCount(); + decls.add(decl); + mapDeclToID.add(decl, id); + + return id; + } + } + + void encodePtr(Decl* decl) + { + DeclID id = getDeclID(decl); + encoder->encode(id); + } - // Lets serialize it all - writer.addPointer(node); - // Let's stick it all in a stream - writer.write(&stream); + ValID getValID(Val* val) + { + SLANG_ASSERT(val != nullptr); - stream.swapContents(contents); + if (auto found = mapValToID.tryGetValue(val)) + return *found; - NamePool namePool; - namePool.setRootNamePool(rootNamePool); + // In order to ensure that values can be fully constructed + // from the get-go (so that they will get cached correctly), + // we conspire to ensure that every value is preceded by + // all of its operands. + // + for (auto operand : val->m_operands) + { + switch (operand.kind) + { + default: + break; - ASTBuilder builder(sharedASTBuilder, "Serialize Check"); + case ValNodeOperandKind::ValNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as(operandNode)); + getValID(static_cast(operandNode)); + } + break; - SetASTBuilderContextRAII astBuilderRAII(&builder); + case ValNodeOperandKind::ASTNode: + if (auto operandNode = operand.values.nodeOperand) + { + SLANG_ASSERT(as(operandNode)); + getDeclID(static_cast(operandNode)); + } + break; + } + } + auto resolved = val->resolve(); + if (resolved != val) + { + getValID(resolved); + } - DefaultSerialObjectFactory objectFactory(&builder); + ValID id = vals.getCount(); + vals.add(val); + mapValToID.add(val, id); + return id; + } - // We could now check that the loaded data matches + void encodePtr(Val* val) + { + ValID id = getValID(val); + encoder->encode(id); + } + void encodeImportedDecl(ImportedDeclInfo const& info) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(info.moduleIndex); + auto decl = info.decl; + if (auto importedModuleDecl = as(decl)) + { + SLANG_ASSERT(info.moduleIndex == -1); + encode(importedModuleDecl->getName()); + } + else { - const List& writtenEntries = writer.getEntries(); - List readEntries; + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); + encode(mangledName); + } + } - SlangResult res = SerialReader::loadEntries( - contents.getBuffer(), - contents.getCount(), - classes, - readEntries); - SLANG_UNUSED(res); + void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } + void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } + void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } - SLANG_ASSERT(writtenEntries.getCount() == readEntries.getCount()); + void encodePtr(Name* name) { encode(name->text); } - // They should be identical up to the - for (Index i = 1; i < readEntries.getCount(); ++i) - { - auto writtenEntry = writtenEntries[i]; - auto readEntry = readEntries[i]; + void encodePtr(MarkupEntry* entry) + { + // TODO: is this case needed? + SLANG_UNUSED(entry); + } + + void encodePtr(DeclAssociationList* list) + { + // We serialize this as if it were a simple list + // of key-value pairs because... well... that's + // what it amounts to in practice. + // + Encoder::WithArray withArray(encoder); + for (auto association : list->associations) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(association->kind); + encode(association->decl); + } + } + + void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } + + void encodePtr(WitnessTable* witnessTable) + { + Encoder::WithObject withObject(encoder); + encode(witnessTable->baseType); + encode(witnessTable->witnessedType); + encode(witnessTable->isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + encode(witnessTable->m_requirementDictionary); + } + + void encodeValue(RequirementWitness const& witness) + { + Encoder::WithKeyValuePair withPair(encoder); + encodeEnum(witness.m_flavor); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + encode(witness.m_declRef); + break; + + case RequirementWitness::Flavor::val: + encode(witness.m_val); + break; + + case RequirementWitness::Flavor::witnessTable: + encode((WitnessTable*)witness.m_obj.Ptr()); + break; + } + } + + void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } + + void encodePtr(DeclBase* declBase) + { + if (auto decl = as(declBase)) + { + encodePtr(decl); + } + else + { + encodeASTNodeContent(declBase); + } + } + + void encodeValue(UnhandledCase); + + void encodeValue(String const& value) { encoder->encode(value); } + + void encodeValue(Token const& value) + { + encode(value.type); + encode(TokenFlags(value.flags & ~TokenFlag::Name)); + encode(value.loc); + if (value.hasContent()) + encoder->encodeString(value.getContent()); + else + encode(nullptr); + } + + void encodeValue(NameLoc const& value) { encode(value.name); } + + void encodeValue(SemanticVersion value) { encoder->encode(value.toInteger()); } + + void encodeValue(CapabilitySet const& value) + { + // While the `CapabilityTargetSets` type is a dictionary, + // in practice each entry already embeds its own key + // (the target atom), so we can encode this as just + // an array of the `CapabilityTargetSet` values. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.getCapabilityTargetSets()) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityTargetSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.target); + + // Similar to the case for the `CapabilityTargetSets` above, + // each `CapabilityStageSet` already includes the stage atom, + // so we can simply encode the values from the dictionary. + // + Encoder::WithArray withArray(encoder); + for (auto pair : value.shaderStageSets) + { + encode(pair.second); + } + } + + void encodeValue(CapabilityStageSet const& value) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(value.stage); + encode(value.atomSet); + } + + void encodeValue(CapabilityAtomSet const& value) + { + Encoder::WithArray withArray(encoder); + for (auto rawAtom : value) + { + encode(CapabilityAtom(rawAtom)); + } + } + + template + void encodeValue(std::optional const& value) + { + if (value) + encodeValue(*value); + else + encoder->encode(nullptr); + } + + void encodeValue(SyntaxClass const& value) { encode(value.getTag()); } + + template + void encodeValue(DeclRef const& value) + { + encode((DeclRefBase*)value); + } + + void encodeValue(ValNodeOperand value) + { + Encoder::WithKeyValuePair withPair(encoder); + + encodeEnum(value.kind); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + encode(value.values.intOperand); + break; - const size_t writtenSize = writtenEntry->calcSize(classes); - const size_t readSize = readEntry->calcSize(classes); - SLANG_UNUSED(writtenSize); - SLANG_UNUSED(readSize); + case ValNodeOperandKind::ValNode: + encode(static_cast(value.values.nodeOperand)); + break; - SLANG_ASSERT(readSize == writtenSize); - // Check the payload is the same - SLANG_ASSERT(memcmp(readEntry, writtenEntry, readSize) == 0); + case ValNodeOperandKind::ASTNode: + { + if (auto decl = as(value.values.nodeOperand)) + { + encode(decl); + } + else + { + SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); + } } + break; + } + } + + void encodeValue(TypeExp value) { encode(value.type); } + + void encodeValue(QualType value) + { + Encoder::WithObject withObject(encoder); + encode(value.type); + encode(value.isLeftValue); + encode(value.hasReadOnlyOnTarget); + encode(value.isWriteOnly); + } + + void encodeValue(MatrixCoord value) + { + Encoder::WithObject withObject(encoder); + encode(value.row); + encode(value.col); + } + + void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } + + void encodeValue(SPIRVAsmOperand const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.flavor); + encode(value.token); + encode(value.expr); + encode(value.bitwiseOrWith); + encode(value.knownValue); + encode(value.wrapInId); + encode(value.type); + } + + void encodeValue(SPIRVAsmInst const& value) + { + Encoder::WithObject withObject(encoder); + encode(value.opcode); + encode(value.operands); + } + + + template>> + void encodeValue(T value) + { + encoder->encodeBool(value); + } + + void encodeValue(Int32 value) { encoder->encode(value); } + void encodeValue(UInt32 value) { encoder->encode(value); } + void encodeValue(Int64 value) { encoder->encode(value); } + void encodeValue(UInt64 value) { encoder->encode(value); } + void encodeValue(float value) { encoder->encode(value); } + void encodeValue(double value) { encoder->encode(value); } + + void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } + + void encodeValue(nullptr_t) { encoder->encode(nullptr); } + + template + void encodeEnum(T value) + { + encoder->encode(Int32(value)); + } + + void encodeValue(DeclVisibility value) { encodeEnum(value); } + void encodeValue(BaseType value) { encodeEnum(value); } + void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } + void encodeValue(ASTNodeType value) { encodeEnum(value); } + void encodeValue(ImageFormat value) { encodeEnum(value); } + void encodeValue(TypeTag value) { encodeEnum(value); } + void encodeValue(TryClauseType value) { encodeEnum(value); } + void encodeValue(CapabilityAtom value) { encodeEnum(value); } + void encodeValue(DeclAssociationKind value) { encodeEnum(value); } + void encodeValue(TokenType value) { encodeEnum(value); } + + void encodeValue(SourceLoc value) + { + if (!_sourceLocWriter) + { + encoder->encode(nullptr); } + else + { + auto intermediate = _sourceLocWriter->addSourceLoc(value); + encoder->encode(intermediate); + } + } + + template + void encodeValue(T const* ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(const_cast(ptr)); + } + } + + template + void encodeValue(RefPtr const& ptr) + { + if (!ptr) + { + encoder->encode(nullptr); + } + else + { + encodePtr(ptr.Ptr()); + } + } - SerialReader reader(classes, nullptr); + void encodeValue(Modifiers const& modifiers) + { + Encoder::WithArray withArray(encoder); + for (auto m : const_cast(modifiers)) { + encode(m); + } + } - SlangResult res = reader.load(contents.getBuffer(), contents.getCount(), &namePool); - SLANG_UNUSED(res); + template + void encodeValue(ShortList const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) + { + encode(element); } + } - // Lets see what we have - const ASTDumpUtil::Flags dumpFlags = - ASTDumpUtil::Flag::HideSourceLoc | ASTDumpUtil::Flag::HideScope; - String readDump; + template + void encode(List const& array) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump( - reader.getPointer(SerialIndex(1)).dynamicCast(), - ASTDumpUtil::Style::Hierachical, - dumpFlags, - &sourceWriter); - readDump = sourceWriter.getContentAndClear(); + encode(element); } - String origDump; + } + + template + void encode(T const (&array)[N]) + { + Encoder::WithArray withArray(encoder); + for (auto element : array) { - SourceWriter sourceWriter(sourceManager, LineDirectiveMode::None, nullptr); - ASTDumpUtil::dump(node, ASTDumpUtil::Style::Hierachical, dumpFlags, &sourceWriter); - origDump = sourceWriter.getContentAndClear(); + encode(element); } + } - // Write out - File::writeAllText("ast-read.ast-dump", readDump); - File::writeAllText("ast-orig.ast-dump", origDump); + template + void encode(OrderedDictionary const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) + { + Encoder::WithKeyValuePair withPair(encoder); + encode(p.key); + encode(p.value); + } + } - if (readDump != origDump) + template + void encode(Dictionary const& dictionary) + { + Encoder::WithArray withArray(encoder); + for (auto p : dictionary) { - return SLANG_FAIL; + Encoder::WithKeyValuePair withPair(encoder); + encode(p.first); + encode(p.second); } } - return SLANG_OK; + template + void encode(T const& value) + { + encodeValue(value); + } + + // for each class of node, we generate + // code to recursively serialize each + // of its fields. + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _encodeDataOf($T* obj) + { +%if T.directSuperClass then + _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); +%end +%for _,f in ipairs(T.directFields) do + encode(obj->$f); +%end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +void writeSerializedModuleAST( + Encoder* encoder, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) +{ + Encoder::WithObject withObject(encoder); + + // TODO: we should have a more careful pass here, + // where we only encode the public declarations + // + + ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); + context.getDeclID(moduleDecl); + context.flush(); } -/* static */ List ASTSerialUtil::serializeAST(ModuleDecl* moduleDecl) +struct ASTDecodingContext { - // TODO: we should store `classes` in GlobalSession to avoid recomputing them every time. - RefPtr classes; - SerialClassesUtil::create(classes); +public: + ASTDecodingContext( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* rootChunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) + : _linkage(linkage) + , _astBuilder(astBuilder) + , _sink(sink) + , _rootChunk(static_cast(rootChunk)) + , _sourceLocReader(sourceLocReader) + , _requestingSourceLoc(requestingSourceLoc) + { + } - List contents; - OwnedMemoryStream stream(FileAccess::ReadWrite); + Linkage* _linkage = nullptr; + DiagnosticSink* _sink = nullptr; + SerialSourceLocReader* _sourceLocReader = nullptr; + SourceLoc _requestingSourceLoc; - // Only serialize out things *in* this module - ModuleSerialFilter filterStorage(moduleDecl); + SlangResult decodeAll() + { + auto cursor = _rootChunk->getFirstContainedChunk(); - SerialFilter* filter = moduleDecl ? &filterStorage : nullptr; + // There are a few different top-level chunks that + // hold different arrays that we need in order + // to decode the entire module hierarchy. + // + // Basically, these lists correspond to the kinds + // of nodes in the AST hierarchy for which back-references + // are allowed (all other nodes should, barring + // weird corner cases, form a single tree-structured + // ownership hierarchy, rooted at the `ModuleDecl`. + // - SerialWriter writer(classes, filter); + // First there is the list that actually encodes + // for the declarations in the module, including + // the `ModuleDecl` itself, which should be the + // first entry in the list. + // + auto declChunk = cursor; + cursor = cursor->m_next; - // Lets serialize it all - writer.addPointer(moduleDecl); - // Let's stick it all in a stream - writer.write(&stream); + // Next there is a list of all the declarations + // referenced inside of the module that need to + // be imported in from outside. + // + auto importedDeclChunk = cursor; + cursor = cursor->m_next; - stream.swapContents(contents); - return contents; -} + // Then there are all the `Val`-derived nodes that + // are needed by the module, which will need to be + // deduplicated so that they are unique within the + // current compilation context. + // + auto valChunk = cursor; + cursor = cursor->m_next; + + // The process of decoding the module is then spread + // over a number of steps. + // + // The first step is to process all of the imported + // declarations, so that other nodes can refer to + // them. + // + SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); + + // Next we process the declarations that are within + // the module itself, first creating an "empty shell" + // of each declaration that has the right size in + // memory (and the right `ASTNodeType` tag), so that + // we can wire up references to it (including circular + // references)... so long as nothing here tries to + // look *inside* the empty shell along the way. + // + SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); + + // Once all the `Decl`s that might be needed have + // been allocated, we can process all the `Val`s + // that might reference those`Decl`s (and one another). + // + // The nature of the `Val` representation ensures + // that there cannot be cirularities in the references + // between `Val`s, and the encoding process will have + // sorted the entries so that a `Val` only ever appears + // *after* its operands. + // + SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + + // Once all the back-reference-able objects have been + // instantiated in memory, we can go back through the + // `Decl`s in the module and fill in those empty shells. + // + SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); + + // As a final pass, we perform any special cleanup actions + // that might be required to make the output valid for consumers. + // + // For example, this is where we set the `DeclCheckState` of everything + // we are loading to reflect the fact that everything we deserialize + // is (supposed to be) fully cheked. + // + SLANG_RETURN_ON_FAIL(cleanUpNodes()); + + + return SLANG_OK; + } + + typedef Int DeclID; + Decl* getDeclByID(DeclID id) + { + if (id >= 0) + { + return _decls[id]; + } + else + { + return _importedDecls[~id]; + } + } + +private: + struct UnhandledCase + { + }; + + ASTBuilder* _astBuilder = nullptr; + RiffContainer::ListChunk* _rootChunk = nullptr; + + List _decls; + List _importedDecls; + List _vals; + + typedef Int ValID; + Val* getValByID(ValID id) { return _vals[id]; } + + SlangResult decodeImportedDecls(RiffContainer::Chunk* importedDeclChunk) + { + Decoder decoder(importedDeclChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + Int moduleIndex; + decode(moduleIndex, decoder); + + if (moduleIndex == -1) + { + Name* moduleName = nullptr; + decode(moduleName, decoder); + + Decl* importedModule = getImportedModule(moduleName); + _importedDecls.add(importedModule); + } + else + { + auto importedFromModuleDecl = as(_importedDecls[moduleIndex]); + auto importedFromModule = importedFromModuleDecl->module; + + String mangledName; + decode(mangledName, decoder); + + auto importedNode = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + auto importedDecl = as(importedNode); + _importedDecls.add(importedDecl); + } + } + return SLANG_OK; + } + + ModuleDecl* getImportedModule(Name* moduleName) + { + Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) + { + SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); + } + + return module->getModuleDecl(); + } + + SlangResult decodeVals(RiffContainer::Chunk* valChunk) + { + Decoder decoder(valChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Val* val = decodeValNode(decoder); + _vals.add(val); + } + return SLANG_OK; + } + SlangResult createEmptyShells(RiffContainer::Chunk* declChunk) + { + Decoder decoder(declChunk); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ASTNodeType nodeType; + + // Each of the declarations is expected to take + // the form of an object with a first field + // that holds the node type. + // + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + + auto emptyShell = createEmptyShell(nodeType); + auto declEmptyShell = as(emptyShell); + _decls.add(declEmptyShell); + } + + return SLANG_OK; + } + + Val* decodeValNode(Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + + ASTNodeType nodeType; + decode(nodeType, decoder); + + ValNodeDesc desc; + desc.type = SyntaxClass(nodeType); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + ValNodeOperand operand; + decode(operand, decoder); + desc.operands.add(operand); + } + + desc.init(); + + auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); + + // Values created during deserialization are + // not expected to ever resolve further, because + // they should be coming from fully checked code. + // + // val->resolve(); + // val->_setUnique(); + + return val; + } + + NodeBase* createEmptyShell(ASTNodeType nodeType) + { + return SyntaxClass(nodeType).createInstance(_astBuilder); + } + + SlangResult fillEmptyShells(RiffContainer::Chunk* declChunk) + { + Index declIndex = 0; + + Decoder decoder(declChunk); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto declEmptyShell = _decls[declIndex++]; + decodeASTNodeContent(declEmptyShell, decoder); + } + + return SLANG_OK; + } + + SlangResult cleanUpNodes() + { + for (auto decl : _decls) + { + decl->checkState = DeclCheckState::CapabilityChecked; + } + + return SLANG_OK; + } + + + void assignGenericParameterIndices(GenericDecl* genericDecl) + { + int parameterCounter = 0; + for (auto m : genericDecl->members) + { + if (auto typeParam = as(m)) + { + typeParam->parameterIndex = parameterCounter++; + } + else if (auto valParam = as(m)) + { + valParam->parameterIndex = parameterCounter++; + } + } + } + + + void cleanUpASTNode(NodeBase* node) + { + if (auto expr = as(node)) + { + expr->checked = true; + } + else if (auto genericDecl = as(node)) + { + assignGenericParameterIndices(genericDecl); + } + else if (auto syntaxDecl = as(node)) + { + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); + } + else if (auto namespaceLikeDecl = as(node)) + { + auto declScope = _astBuilder->create(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; + } + } + + void decodeASTNodeContent(NodeBase* node, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + + ASTNodeDispatcher::dispatch( + node, + [&](auto n) { _decodeDataOf(n, decoder); }); + + cleanUpASTNode(node); + } + + DeclID decodeDeclID(Decoder& decoder) + { + DeclID result = decoder.decode(); + return result; + } + + ValID decodeValID(Decoder& decoder) + { + ValID result = decoder.decode(); + return result; + } + + template + void decodeASTNode(T*& node, Decoder& decoder) + { + ASTNodeType nodeType; + auto saved = decoder.getCursor(); + { + Decoder::WithObject withObject(decoder); + decode(nodeType, decoder); + } + decoder.setCursor(saved); + + auto shell = createEmptyShell(nodeType); + decodeASTNodeContent(shell, decoder); + + node = as(shell); + } + + void decodePtr(Name*& name, Decoder& decoder, Name*) + { + String text; + decode(text, decoder); + + name = _astBuilder->getNamePool()->getName(text); + } + + void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) + { + // Mirroring the encoding logic, we decode this + // as a list of key-value pairs. + // + auto list = RefPtr(new DeclAssociationList()); + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + auto association = RefPtr(new DeclAssociation()); + + Decoder::WithKeyValuePair withPair(decoder); + decode(association->kind, decoder); + decode(association->decl, decoder); + + list->associations.add(association); + } + + outList = list.detach(); + } + + void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) + { + Int id; + decode(id, decoder); + info = getDiagnosticsLookup()->getDiagnosticById(id); + } + + void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) + { + // TODO: is this case needed? + markupEntry = nullptr; + } + + void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) + { + auto result = RefPtr(new CandidateExtensionList()); + decode(result->candidateExtensions, decoder); + list = result.detach(); + } + + void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) + { + Decoder::WithObject withObject(decoder); + auto wt = RefPtr(new WitnessTable()); + decode(wt->baseType, decoder); + decode(wt->witnessedType, decoder); + decode(wt->isExtern, decoder); + decode(wt->m_requirementDictionary, decoder); + witnessTable = wt.detach(); + } + + void decodeValue(RequirementWitness& witness, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decodeEnum(witness.m_flavor, decoder); + switch (witness.m_flavor) + { + case RequirementWitness::Flavor::none: + break; + + case RequirementWitness::Flavor::declRef: + decode(witness.m_declRef, decoder); + break; + + case RequirementWitness::Flavor::val: + decode(witness.m_val, decoder); + break; + + case RequirementWitness::Flavor::witnessTable: + { + RefPtr object; + decode(object, decoder); + witness.m_obj = object; + } + break; + } + } + + template + void decodePtr(T*& node, Decoder& decoder, Val*) + { + ValID id = decodeValID(decoder); + node = static_cast(getValByID(id)); + } + + template + void decodePtr(T*& node, Decoder& decoder, Decl*) + { + DeclID id = decodeDeclID(decoder); + node = static_cast(getDeclByID(id)); + } + + template + void decodePtr(T*& node, Decoder& decoder, DeclBase*) + { + if (decoder.getTag() == SerialBinary::kInt64FourCC) + { + DeclID id = decodeDeclID(decoder); + node = static_cast(getDeclByID(id)); + } + else + { + decodeASTNode(node, decoder); + } + } + + template + void decodePtr(T*& node, Decoder& decoder, NodeBase*) + { + decodeASTNode(node, decoder); + } + + + void decodeValue(UnhandledCase, Decoder& decoder); + + void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } + + void decodeValue(Token& value, Decoder& decoder) + { + decode(value.type, decoder); + decode(value.flags, decoder); + decode(value.loc, decoder); + if (decoder.decodeNull()) + { + } + else + { + Name* name = nullptr; + decode(name, decoder); + value.setName(name); + } + } + + void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } + + void decodeValue(SemanticVersion& value, Decoder& decoder) + { + SemanticVersion::IntegerType rawValue = decoder.decode(); + value.setFromInteger(rawValue); + } + + void decodeValue(CapabilitySet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + CapabilityTargetSet targetSet; + decode(targetSet, decoder); + value.getCapabilityTargetSets()[targetSet.target] = targetSet; + } + } + + void decodeValue(CapabilityTargetSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.target, decoder); + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + CapabilityStageSet stageSet; + decode(stageSet, decoder); + value.shaderStageSets[stageSet.stage] = stageSet; + } + } + + void decodeValue(CapabilityStageSet& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + decode(value.stage, decoder); + decode(value.atomSet, decoder); + } + + void decodeValue(CapabilityAtomSet& value, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + CapabilityAtom atom; + decode(atom, decoder); + value.add(UInt(atom)); + } + } + + template + void decodeValue(std::optional& outValue, Decoder& decoder) + { + if (decoder.decodeNull()) + { + outValue.reset(); + } + else + { + T value; + decode(value, decoder); + outValue = value; + } + } + + void decodeValue(SyntaxClass& syntaxClass, Decoder& decoder) + { + ASTNodeType nodeType; + decode(nodeType, decoder); + syntaxClass = SyntaxClass(nodeType); + } + + template + void decodeValue(DeclRef& declRef, Decoder& decoder) + { + decode(declRef.declRefBase, decoder); + } + + void decodeValue(ValNodeOperand& value, Decoder& decoder) + { + Decoder::WithKeyValuePair withPair(decoder); + + decodeEnum(value.kind, decoder); + switch (value.kind) + { + case ValNodeOperandKind::ConstantValue: + decode(value.values.intOperand, decoder); + break; + + case ValNodeOperandKind::ValNode: + { + Val* val = nullptr; + decode(val, decoder); + value.values.nodeOperand = val; + } + break; + + case ValNodeOperandKind::ASTNode: + { + Decl* decl = nullptr; + decode(decl, decoder); + value.values.nodeOperand = decl; + } + break; + } + } + + void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + + void decodeValue(QualType& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.type, decoder); + decode(value.isLeftValue, decoder); + decode(value.hasReadOnlyOnTarget, decoder); + decode(value.isWriteOnly, decoder); + } + + void decodeValue(MatrixCoord& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.row, decoder); + decode(value.col, decoder); + } + + void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + + void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.flavor, decoder); + decode(value.token, decoder); + decode(value.expr, decoder); + decode(value.bitwiseOrWith, decoder); + decode(value.knownValue, decoder); + decode(value.wrapInId, decoder); + decode(value.type, decoder); + } + + void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + { + Decoder::WithObject withObject(decoder); + decode(value.opcode, decoder); + decode(value.operands, decoder); + } + + + template + void decodeEnum(T& value, Decoder& decoder) + { + value = T(decoder.decode()); + } + + template + void decodeSimpleValue(T& value, Decoder& decoder) + { + value = decoder.decode(); + } + + void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } + void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + + void decodeValue(uint8_t& value, Decoder& decoder) + { + value = uint8_t(decoder.decode()); + } + + void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) + { + decodeEnum(value, decoder); + } + void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } + void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + + + void decodeValue(SourceLoc& value, Decoder& decoder) + { + if (!decoder.decodeNull()) + { + SerialSourceLocData::SourceLoc intermediate; + decoder.decode(intermediate); + + if (_sourceLocReader) + { + auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); + value = sourceLoc; + } + } + } + + template + void decodeValue(T*& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + decodePtr(ptr, decoder, (T*)nullptr); + } + + template + void decodeValue(RefPtr& ptr, Decoder& decoder) + { + if (decoder.decodeNull()) + ptr = nullptr; + else + { + // Hi Future Tess, + // + // The next step here is decoding logic for `WitnessTable`s. + // + + decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); + } + } + + void decodeValue(Modifiers& modifiers, Decoder& decoder) + { + Modifier** link = &modifiers.first; + + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Modifier* modifier = nullptr; + decode(modifier, decoder); + + *link = modifier; + link = &modifier->next; + } + } + + template + void decodeValue(ShortList& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } + + + template + void decode(List& array, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + T element; + decode(element, decoder); + array.add(element); + } + } + + template + void decode(T (&array)[N], Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + for (auto& element : array) + { + decode(element, decoder); + } + } + + template + void decode(OrderedDictionary& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template + void decode(Dictionary& dictionary, Decoder& decoder) + { + Decoder::WithArray withArray(decoder); + while (decoder.hasElements()) + { + Decoder::WithKeyValuePair withPair(decoder); + + K key; + V value; + decode(key, decoder); + decode(value, decoder); + + dictionary.add(key, value); + } + } + + template + void decode(T& outValue, Decoder& decoder) + { + decodeValue(outValue, decoder); + } + +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do + void _decodeDataOf($T* obj, Decoder& decoder) + { +% if T.directSuperClass then + _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); +% end +% for _,f in ipairs(T.directFields) do + decode(obj->$f, decoder); +% end + } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 1 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END +}; + +ModuleDecl* readSerializedModuleAST( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RiffContainer::Chunk* chunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) +{ + ASTDecodingContext + context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); + context.decodeAll(); + auto node = context.getDeclByID(0); + auto moduleDecl = as(node); + return moduleDecl; +} } // namespace Slang -- cgit v1.2.3