From 9059093bc764e901a9c4aaeb12471bf32028874f Mon Sep 17 00:00:00 2001 From: Theresa Foley <10618364+tangent-vector@users.noreply.github.com> Date: Tue, 20 May 2025 21:55:39 -0700 Subject: Generalize serialization system used for AST (#7126) This change takes the new approach to serialization that was used for the AST and generalizes it in a few ways: * The new approach is no longer tangled up with the RIFF format. The serialization system supports multiple different implementations of the underlying format. The existing RIFF format is now supported as one back-end, but support for others will follow in subsequent changes. * The new approach is no longer deeply specialized to AST serialization. The old code had things like serialization for `List`s and `Dictionary`s, but it was embedded inside the `AST{Encoding|Decoding}Context`, and thus couldn't be leveraged for other serialization tasks. This change factors out a completely AST-independent `Serializer` implementation, with an `ASTSerializer` layered on top of it to provide the additional context needed. * There is less duplication of code between reading and writing of serialized data. The old code had both the `ASTEncodingContext` and `ASTDecodingContext`, with serialization logic for most types being implemented in both, but with the constraint that those implementations needed to be kept in sync to avoid serialization-related runtime failures. A key property of the revamped approach is that a single `serialize()` method for a type implements both the reading and writing directions of serialization. --- source/slang/slang-serialize-ast.cpp | 1990 ++++++++++++---------------------- 1 file changed, 682 insertions(+), 1308 deletions(-) (limited to 'source/slang/slang-serialize-ast.cpp') diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 1ef532ad1..9a61dbc5a 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -5,6 +5,7 @@ #include "slang-compiler.h" #include "slang-diagnostics.h" #include "slang-mangle.h" +#include "slang-serialize-riff.h" namespace Slang { @@ -13,1528 +14,900 @@ namespace Slang // NodeBase* parseSimpleSyntax(Parser* parser, void* userData); +// +// Many of the types used in the AST can be serialized using +// just the `Serializer` type, so we will handle all of those first. +// -struct ASTEncodingContext +void serialize(Serializer const& serializer, ASTNodeType& value) { -private: - Encoder* encoder; - struct UnhandledCase - { - }; - - typedef Int DeclID; - Dictionary mapDeclToID; - List decls; - - struct ImportedDeclInfo - { - Int moduleIndex = -1; - Decl* decl; - }; - List importedDecls; - - typedef Int ValID; - Dictionary mapValToID; - List vals; - - ModuleDecl* _module = nullptr; - - SerialSourceLocWriter* _sourceLocWriter = nullptr; - -public: - ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) - : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) - { - } - - template - void encodeASTNodeContent(T* node) - { - Encoder::WithObject withObject(encoder); - - ASTNodeDispatcher::dispatch(node, [&](auto n) { _encodeDataOf(n); }); - } - - void flush() - { - auto containerChunk = encoder->getRIFFChunk(); - - RIFF::ChunkBuilder* declChunk = nullptr; - RIFF::ChunkBuilder* importedDeclChunk = nullptr; - RIFF::ChunkBuilder* valChunk = nullptr; - { - Encoder::WithArray withList(encoder); - declChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - importedDeclChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - valChunk = encoder->getRIFFChunk(); - } - Int declIndex = 0; - Int importedDeclIndex = 0; - Int valIndex = 0; + serializeEnum(serializer, value); +} - bool done = false; - do - { - done = true; - while (declIndex < decls.getCount()) - { - done = false; - encoder->setRIFFChunk(declChunk); - encodeASTNodeContent(decls[declIndex++]); - } - while (importedDeclIndex < importedDecls.getCount()) - { - done = false; - encoder->setRIFFChunk(importedDeclChunk); - encodeImportedDecl(importedDecls[importedDeclIndex++]); - } - while (valIndex < vals.getCount()) - { - done = false; - encoder->setRIFFChunk(valChunk); - encodeASTNodeContent(vals[valIndex++]); - } - } while (!done); +void serialize(Serializer const& serializer, TypeTag& value) +{ + serializeEnum(serializer, value); +} - encoder->setRIFFChunk(containerChunk); - } +void serialize(Serializer const& serializer, BaseType& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleForDecl(Decl* decl) - { - for (auto d = decl; d; d = d->parentDecl) - { - if (auto m = as(d)) - return m; - } - return nullptr; - } +void serialize(Serializer const& serializer, TryClauseType& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) - { - auto declModule = findModuleForDecl(decl); - if (declModule == nullptr) - return nullptr; - if (declModule == _module) - return nullptr; - return declModule; - } +void serialize(Serializer const& serializer, DeclVisibility& value) +{ + serializeEnum(serializer, value); +} - DeclID getDeclID(Decl* decl) - { - SLANG_ASSERT(decl != nullptr); +void serialize(Serializer const& serializer, BuiltinRequirementKind& value) +{ + serializeEnum(serializer, value); +} - if (auto found = mapDeclToID.tryGetValue(decl)) - return *found; +void serialize(Serializer const& serializer, ImageFormat& value) +{ + serializeEnum(serializer, value); +} - // We need to detect whether the declaration is an - // imported one, or one from this module itself. - // - // Imported declarations need to be handled very - // differently, since they'll involve resolving - // references to those other modules, and the - // declarations within them. - // - if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) - { - DeclID importedFromModuleDeclID = 0; - if (decl != importedFromModule) - { - importedFromModuleDeclID = getDeclID(importedFromModule); - } +void serialize(Serializer const& serializer, PreferRecomputeAttribute::SideEffectBehavior& value) +{ + serializeEnum(serializer, value); +} - DeclID id = ~importedDecls.getCount(); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, TreatAsDifferentiableExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - ImportedDeclInfo info; - info.moduleIndex = ~importedFromModuleDeclID; - info.decl = decl; - importedDecls.add(info); +void serialize(Serializer const& serializer, LogicOperatorShortCircuitExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - return id; - } - else - { - DeclID id = decls.getCount(); - decls.add(decl); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, RequirementWitness::Flavor& value) +{ + serializeEnum(serializer, value); +} - return id; - } - } +void serialize(Serializer const& serializer, CapabilityAtom& value) +{ + serializeEnum(serializer, value); +} - void encodePtr(Decl* decl) - { - DeclID id = getDeclID(decl); - encoder->encode(id); - } +void serialize(Serializer const& serializer, DeclAssociationKind& value) +{ + serializeEnum(serializer, value); +} - ValID getValID(Val* val) - { - SLANG_ASSERT(val != nullptr); +void serialize(Serializer const& serializer, TokenType& value) +{ + serializeEnum(serializer, value); +} - if (auto found = mapValToID.tryGetValue(val)) - return *found; +void serialize(Serializer const& serializer, ValNodeOperandKind& value) +{ + serializeEnum(serializer, value); +} - // In order to ensure that values can be fully constructed - // from the get-go (so that they will get cached correctly), - // we conspire to ensure that every value is preceded by - // all of its operands. - // - for (auto operand : val->m_operands) - { - switch (operand.kind) - { - default: - break; - - case ValNodeOperandKind::ValNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as(operandNode)); - getValID(static_cast(operandNode)); - } - break; - - case ValNodeOperandKind::ASTNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as(operandNode)); - getDeclID(static_cast(operandNode)); - } - break; - } - } - auto resolved = val->resolve(); - if (resolved != val) - { - getValID(resolved); - } +void serialize(Serializer const& serializer, SPIRVAsmOperand::Flavor& value) +{ + serializeEnum(serializer, value); +} - ValID id = vals.getCount(); - vals.add(val); - mapValToID.add(val, id); - return id; - } +void serialize(Serializer const& serializer, MatrixCoord& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.row); + serialize(serializer, value.col); +} - void encodePtr(Val* val) +void serializePtr(Serializer const& serializer, DiagnosticInfo const*& value, DiagnosticInfo const*) +{ + Int32 id = 0; + if (isWriting(serializer)) { - ValID id = getValID(val); - encoder->encode(id); + id = value->id; + serialize(serializer, id); } - - void encodeImportedDecl(ImportedDeclInfo const& info) + else { - Encoder::WithKeyValuePair withPair(encoder); - encode(info.moduleIndex); - auto decl = info.decl; - if (auto importedModuleDecl = as(decl)) - { - SLANG_ASSERT(info.moduleIndex == -1); - encode(importedModuleDecl->getName()); - } - else - { - auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - encode(mangledName); - } + serialize(serializer, id); + value = getDiagnosticsLookup()->getDiagnosticById(id); } +} - void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } - void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } - void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } - - void encodePtr(Name* name) { encode(name->text); } +void serialize(Serializer const& serializer, SemanticVersion& value) +{ + auto raw = value.getRawValue(); + serialize(serializer, raw); + value = SemanticVersion::fromRaw(raw); +} - void encodePtr(MarkupEntry* entry) +void serialize(Serializer const& serializer, SyntaxClass& value) +{ + ASTNodeType raw; + if (isWriting(serializer)) { - // TODO: is this case needed? - SLANG_UNUSED(entry); + raw = value.getTag(); } - - void encodePtr(DeclAssociationList* list) + serialize(serializer, raw); + if (isReading(serializer)) { - // We serialize this as if it were a simple list - // of key-value pairs because... well... that's - // what it amounts to in practice. - // - Encoder::WithArray withArray(encoder); - for (auto association : list->associations) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(association->kind); - encode(association->decl); - } + value = SyntaxClass(raw); } +} - void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } - - void encodePtr(WitnessTable* witnessTable) - { - Encoder::WithObject withObject(encoder); - encode(witnessTable->baseType); - encode(witnessTable->witnessedType); - encode(witnessTable->isExtern); - - // TODO(tfoley): In theory we should be able to streamline - // this so that we only encode the requirements that we - // absolutely need to (which basically amounts to `associatedtype` - // requirements where the satisfying type is part of the public - // API of the type). - // - encode(witnessTable->m_requirementDictionary); - } +// +// Many types in the AST need additional context (beyond +// what the `Serializer` has) in order to serialize +// themselves or their members. +// +// We define a custom serializer interface to capture +// the cases that can't be handled by a `Serializer` +// alone. +// - void encodeValue(RequirementWitness const& witness) - { - Encoder::WithKeyValuePair withPair(encoder); - encodeEnum(witness.m_flavor); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; +/// Interface for AST serialization +struct ASTSerializerImpl +{ +public: + virtual void handleASTNode(NodeBase*& value) = 0; + virtual void handleASTNodeContents(NodeBase* value) = 0; + virtual void handleName(Name*& value) = 0; + virtual void handleSourceLoc(SourceLoc& value) = 0; + virtual void handleToken(Token& value) = 0; - case RequirementWitness::Flavor::declRef: - encode(witness.m_declRef); - break; + // Note that this type does *not* inherit from `ISerializerImpl`. + // + // We want to decouple the AST-specific context information + // from the lower-level details of the serialization format. + // + // Instead of using inheritance, we expect that any + // `ASTSerializerImpl` will aggregate a lower-level + // serializer, and the interface exposes access to + // that base serializer implementation. - case RequirementWitness::Flavor::val: - encode(witness.m_val); - break; + virtual ISerializerImpl* getBaseSerializer() = 0; +}; - case RequirementWitness::Flavor::witnessTable: - encode((WitnessTable*)witness.m_obj.Ptr()); - break; - } - } +/// Specialization of `Serializer_` for AST serialization. +template<> +struct Serializer_ : SerializerBase +{ +public: + using SerializerBase::SerializerBase; - void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } + // + // In order to allow an `ASTSerializer` to be used with + // functions that expect an ordinary `Serializer`, we + // implement an implicit conversion operator. + // - void encodePtr(DeclBase* declBase) - { - if (auto decl = as(declBase)) - { - encodePtr(decl); - } - else - { - encodeASTNodeContent(declBase); - } - } + operator Serializer() const { return Serializer(get()->getBaseSerializer()); } +}; - void encodeValue(UnhandledCase); +/// Context type for AST serialization. +using ASTSerializer = Serializer_; - void encodeValue(String const& value) { encoder->encode(value); } +template +void serializeObject(ASTSerializer const& serializer, T*& value, NodeBase*) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serializer->handleASTNode(*(NodeBase**)&value); +} - void encodeValue(Token const& value) - { - encode(value.type); - encode(TokenFlags(value.flags & ~TokenFlag::Name)); - encode(value.loc); - if (value.hasContent()) - encoder->encodeString(value.getContent()); - else - encode(nullptr); - } +void serializeObjectContents(ASTSerializer const& serializer, NodeBase* value, NodeBase*) +{ + serializer->handleASTNodeContents(value); +} - void encodeValue(NameLoc const& value) { encode(value.name); } +template +void serialize(ASTSerializer const& serializer, DeclRef& value) +{ + serialize(serializer, value.declRefBase); +} - void encodeValue(SemanticVersion value) { encoder->encode(value.getRawValue()); } +void serialize(ASTSerializer const& serializer, SourceLoc& value) +{ + serializer->handleSourceLoc(value); +} - void encodeValue(CapabilitySet const& value) +void serialize(ASTSerializer const& serializer, RequirementWitness& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.m_flavor); + switch (value.m_flavor) { - // While the `CapabilityTargetSets` type is a dictionary, - // in practice each entry already embeds its own key - // (the target atom), so we can encode this as just - // an array of the `CapabilityTargetSet` values. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.getCapabilityTargetSets()) - { - encode(pair.second); - } - } + case RequirementWitness::Flavor::none: + break; - void encodeValue(CapabilityTargetSet const& value) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.target); + case RequirementWitness::Flavor::declRef: + serialize(serializer, value.m_declRef); + break; - // Similar to the case for the `CapabilityTargetSets` above, - // each `CapabilityStageSet` already includes the stage atom, - // so we can simply encode the values from the dictionary. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.shaderStageSets) - { - encode(pair.second); - } - } + case RequirementWitness::Flavor::val: + serialize(serializer, value.m_val); + break; - void encodeValue(CapabilityStageSet const& value) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.stage); - encode(value.atomSet); + case RequirementWitness::Flavor::witnessTable: + serialize(serializer, value.m_obj); + break; } +} - void encodeValue(CapabilityAtomSet const& value) +void serialize(ASTSerializer const& serializer, WitnessTable& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.baseType); + serialize(serializer, value.witnessedType); + serialize(serializer, value.isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + serialize(serializer, value.m_requirementDictionary); +} + +void serialize(Serializer const& serializer, CapabilityAtomSet& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - Encoder::WithArray withArray(encoder); for (auto rawAtom : value) { - encode(CapabilityAtom(rawAtom)); + auto atom = CapabilityAtom(rawAtom); + serialize(serializer, atom); } } - - template - void encodeValue(std::optional const& value) - { - if (value) - encodeValue(*value); - else - encoder->encode(nullptr); - } - - void encodeValue(SyntaxClass const& value) { encode(value.getTag()); } - - template - void encodeValue(DeclRef const& value) + else { - encode((DeclRefBase*)value); - } - - void encodeValue(ValNodeOperand value) - { - Encoder::WithKeyValuePair withPair(encoder); - - encodeEnum(value.kind); - switch (value.kind) + while (hasElements(serializer)) { - case ValNodeOperandKind::ConstantValue: - encode(value.values.intOperand); - break; - - case ValNodeOperandKind::ValNode: - encode(static_cast(value.values.nodeOperand)); - break; - - case ValNodeOperandKind::ASTNode: - { - if (auto decl = as(value.values.nodeOperand)) - { - encode(decl); - } - else - { - SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); - } - } - break; + CapabilityAtom atom; + serialize(serializer, atom); + value.add(UInt(atom)); } } +} - void encodeValue(TypeExp value) { encode(value.type); } - - void encodeValue(QualType value) - { - Encoder::WithObject withObject(encoder); - encode(value.type); - encode(value.isLeftValue); - encode(value.hasReadOnlyOnTarget); - encode(value.isWriteOnly); - } - - void encodeValue(MatrixCoord value) - { - Encoder::WithObject withObject(encoder); - encode(value.row); - encode(value.col); - } - - void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } - - void encodeValue(SPIRVAsmOperand const& value) - { - Encoder::WithObject withObject(encoder); - encode(value.flavor); - encode(value.token); - encode(value.expr); - encode(value.bitwiseOrWith); - encode(value.knownValue); - encode(value.wrapInId); - encode(value.type); - } +void serialize(Serializer const& serializer, CapabilityStageSet& value) +{ + serialize(serializer, value.atomSet); +} - void encodeValue(SPIRVAsmInst const& value) +void serialize(Serializer const& serializer, CapabilityTargetSet& value) +{ + serialize(serializer, value.shaderStageSets); + + // The value for each entry in `shaderStageSets` have + // a `stage` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityStageSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - Encoder::WithObject withObject(encoder); - encode(value.opcode); - encode(value.operands); + for (auto& p : value.shaderStageSets) + p.second.stage = p.first; } +} - - template>> - void encodeValue(T value) +void serialize(Serializer const& serializer, CapabilitySet& value) +{ + serialize(serializer, value.getCapabilityTargetSets()); + + // The value for each entry in `getCapabilityTargetSets()` have + // a `target` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityTargetSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - encoder->encodeBool(value); + for (auto& p : value.getCapabilityTargetSets()) + p.second.target = p.first; } +} - void encodeValue(Int32 value) { encoder->encode(value); } - void encodeValue(UInt32 value) { encoder->encode(value); } - void encodeValue(Int64 value) { encoder->encode(value); } - void encodeValue(UInt64 value) { encoder->encode(value); } - void encodeValue(float value) { encoder->encode(value); } - void encodeValue(double value) { encoder->encode(value); } - - void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } +void serialize(ASTSerializer const& serializer, CandidateExtensionList& value) +{ + serialize(serializer, value.candidateExtensions); +} - void encodeValue(std::nullptr_t) { encoder->encode(nullptr); } +void serialize(ASTSerializer const& serializer, DeclAssociation& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.kind); + serialize(serializer, value.decl); +} - template - void encodeEnum(T value) - { - encoder->encode(Int32(value)); - } +void serialize(ASTSerializer const& serializer, DeclAssociationList& value) +{ + serialize(serializer, value.associations); +} - void encodeValue(DeclVisibility value) { encodeEnum(value); } - void encodeValue(BaseType value) { encodeEnum(value); } - void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } - void encodeValue(ASTNodeType value) { encodeEnum(value); } - void encodeValue(ImageFormat value) { encodeEnum(value); } - void encodeValue(TypeTag value) { encodeEnum(value); } - void encodeValue(TryClauseType value) { encodeEnum(value); } - void encodeValue(CapabilityAtom value) { encodeEnum(value); } - void encodeValue(DeclAssociationKind value) { encodeEnum(value); } - void encodeValue(TokenType value) { encodeEnum(value); } - - void encodeValue(SourceLoc value) +void serialize(ASTSerializer const& serializer, Modifiers& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - if (!_sourceLocWriter) - { - encoder->encode(nullptr); - } - else + for (auto modifier : value) { - auto intermediate = _sourceLocWriter->addSourceLoc(value); - encoder->encode(intermediate); + serialize(serializer, modifier); } } - - template - void encodeValue(T const* ptr) + else { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(const_cast(ptr)); - } - } + Modifier** link = &value.first; - template - void encodeValue(RefPtr const& ptr) - { - if (!ptr) + while (hasElements(serializer)) { - encoder->encode(nullptr); - } - else - { - encodePtr(ptr.Ptr()); - } - } + Modifier* modifier = nullptr; + serialize(serializer, modifier); - void encodeValue(Modifiers const& modifiers) - { - Encoder::WithArray withArray(encoder); - for (auto m : const_cast(modifiers)) - { - encode(m); + *link = modifier; + link = &modifier->next; } } +} - template - void encodeValue(ShortList const& array) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } +void serialize(ASTSerializer const& serializer, TypeExp& value) +{ + serialize(serializer, value.type); +} +void serialize(ASTSerializer const& serializer, QualType& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.isLeftValue); + serialize(serializer, value.hasReadOnlyOnTarget); + serialize(serializer, value.isWriteOnly); +} - template - void encode(List const& array) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } +void serialize(ASTSerializer const& serializer, Token& value) +{ + serializer->handleToken(value); +} - template - void encode(T const (&array)[N]) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } +void serialize(ASTSerializer const& serializer, SPIRVAsmOperand& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.flavor); + serialize(serializer, value.token); + serialize(serializer, value.expr); + serialize(serializer, value.bitwiseOrWith); + serialize(serializer, value.knownValue); + serialize(serializer, value.wrapInId); + serialize(serializer, value.type); +} - template - void encode(OrderedDictionary const& dictionary) - { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.key); - encode(p.value); - } - } +void serialize(ASTSerializer const& serializer, SPIRVAsmInst& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.opcode); + serialize(serializer, value.operands); +} - template - void encode(Dictionary const& dictionary) +void serialize(ASTSerializer const& serializer, ValNodeOperand& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.kind); + switch (value.kind) { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.first); - encode(p.second); - } - } + case ValNodeOperandKind::ConstantValue: + serialize(serializer, value.values.intOperand); + break; - template - void encode(T const& value) - { - encodeValue(value); + case ValNodeOperandKind::ValNode: + case ValNodeOperandKind::ASTNode: + serialize(serializer, value.values.nodeOperand); + break; } +} - // for each class of node, we generate - // code to recursively serialize each - // of its fields. +void serializeObject(ASTSerializer const& serializer, Name*& value, Name*) +{ + serializer->handleName(value); +} + +void serialize(ASTSerializer const& serializer, NameLoc& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.name); + serialize(serializer, value.loc); +} #if 0 // FIDDLE TEMPLATE: %for _,T in ipairs(Slang.NodeBase.subclasses) do - void _encodeDataOf($T* obj) - { -%if T.directSuperClass then - _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); -%end -%for _,f in ipairs(T.directFields) do - encode(obj->$f); -%end - } +void _serializeASTNodeContents(ASTSerializer const& serializer, $T* value) +{ + SLANG_UNUSED(serializer); + SLANG_UNUSED(value); +% if T.directSuperClass then + _serializeASTNodeContents(serializer, static_cast<$(T.directSuperClass)*>(value)); +% end +% for _,f in ipairs(T.directFields) do + serialize(serializer, value->$f); +% end + } %end #else // FIDDLE OUTPUT: #define FIDDLE_GENERATED_OUTPUT_ID 0 #include "slang-serialize-ast.cpp.fiddle" #endif // FIDDLE END -}; -void writeSerializedModuleAST( - Encoder* encoder, - ModuleDecl* moduleDecl, - SerialSourceLocWriter* sourceLocWriter) +void serializeASTNodeContents(ASTSerializer const& serializer, NodeBase* node) { - Encoder::WithObject withObject(encoder); - - // TODO: we should have a more careful pass here, - // where we only encode the public declarations - // - - ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); - context.getDeclID(moduleDecl); - context.flush(); -} - -struct ASTDecodingContext -{ -public: - ASTDecodingContext( - Linkage* linkage, - ASTBuilder* astBuilder, - DiagnosticSink* sink, - RIFF::Chunk const* baseChunk, - SerialSourceLocReader* sourceLocReader, - SourceLoc requestingSourceLoc) - : _linkage(linkage) - , _astBuilder(astBuilder) - , _sink(sink) - , _baseChunk(as(baseChunk)) - , _sourceLocReader(sourceLocReader) - , _requestingSourceLoc(requestingSourceLoc) - { - } - - Linkage* _linkage = nullptr; - DiagnosticSink* _sink = nullptr; - SerialSourceLocReader* _sourceLocReader = nullptr; - SourceLoc _requestingSourceLoc; - - SlangResult decodeAll() - { - auto cursor = _baseChunk->getChildren().begin(); - - // There are a few different top-level chunks that - // hold different arrays that we need in order - // to decode the entire module hierarchy. - // - // Basically, these lists correspond to the kinds - // of nodes in the AST hierarchy for which back-references - // are allowed (all other nodes should, barring - // weird corner cases, form a single tree-structured - // ownership hierarchy, rooted at the `ModuleDecl`. - // - - // First there is the list that actually encodes - // for the declarations in the module, including - // the `ModuleDecl` itself, which should be the - // first entry in the list. - // - auto declChunk = *cursor; - ++cursor; - - // Next there is a list of all the declarations - // referenced inside of the module that need to - // be imported in from outside. - // - auto importedDeclChunk = *cursor; - ++cursor; - - // Then there are all the `Val`-derived nodes that - // are needed by the module, which will need to be - // deduplicated so that they are unique within the - // current compilation context. - // - auto valChunk = *cursor; - ++cursor; - - // The process of decoding the module is then spread - // over a number of steps. - // - // The first step is to process all of the imported - // declarations, so that other nodes can refer to - // them. - // - SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); - - // Next we process the declarations that are within - // the module itself, first creating an "empty shell" - // of each declaration that has the right size in - // memory (and the right `ASTNodeType` tag), so that - // we can wire up references to it (including circular - // references)... so long as nothing here tries to - // look *inside* the empty shell along the way. - // - SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); - - // Once all the `Decl`s that might be needed have - // been allocated, we can process all the `Val`s - // that might reference those`Decl`s (and one another). - // - // The nature of the `Val` representation ensures - // that there cannot be cirularities in the references - // between `Val`s, and the encoding process will have - // sorted the entries so that a `Val` only ever appears - // *after* its operands. - // - SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); - - // Once all the back-reference-able objects have been - // instantiated in memory, we can go back through the - // `Decl`s in the module and fill in those empty shells. - // - SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); - - // As a final pass, we perform any special cleanup actions - // that might be required to make the output valid for consumers. - // - // For example, this is where we set the `DeclCheckState` of everything - // we are loading to reflect the fact that everything we deserialize - // is (supposed to be) fully cheked. - // - SLANG_RETURN_ON_FAIL(cleanUpNodes()); - - - return SLANG_OK; - } - - typedef Int DeclID; - Decl* getDeclByID(DeclID id) - { - if (id >= 0) - { - return _decls[id]; - } - else - { - return _importedDecls[~id]; - } - } - -private: - struct UnhandledCase - { - }; - - ASTBuilder* _astBuilder = nullptr; - RIFF::ListChunk const* _baseChunk = nullptr; - - List _decls; - List _importedDecls; - List _vals; - - typedef Int ValID; - Val* getValByID(ValID id) { return _vals[id]; } - - SlangResult decodeImportedDecls(RIFF::Chunk const* importedDeclChunk) - { - Decoder decoder(importedDeclChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); - - Int moduleIndex; - decode(moduleIndex, decoder); - - if (moduleIndex == -1) - { - Name* moduleName = nullptr; - decode(moduleName, decoder); - - Decl* importedModule = getImportedModule(moduleName); - _importedDecls.add(importedModule); - } - else - { - auto importedFromModuleDecl = as(_importedDecls[moduleIndex]); - auto importedFromModule = importedFromModuleDecl->module; - - String mangledName; - decode(mangledName, decoder); - - auto importedNode = - importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); - auto importedDecl = as(importedNode); - _importedDecls.add(importedDecl); - } - } - return SLANG_OK; - } - - ModuleDecl* getImportedModule(Name* moduleName) - { - Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); - if (!module) - { - SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); - } - - return module->getModuleDecl(); - } - - SlangResult decodeVals(RIFF::Chunk const* valChunk) - { - Decoder decoder(valChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Val* val = decodeValNode(decoder); - _vals.add(val); - } - return SLANG_OK; - } - - SlangResult createEmptyShells(RIFF::Chunk const* declChunk) - { - Decoder decoder(declChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ASTNodeType nodeType; - - // Each of the declarations is expected to take - // the form of an object with a first field - // that holds the node type. - // - { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); - } - - auto emptyShell = createEmptyShell(nodeType); - auto declEmptyShell = as(emptyShell); - _decls.add(declEmptyShell); - } - - return SLANG_OK; - } - - Val* decodeValNode(Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeType nodeType; - decode(nodeType, decoder); - - ValNodeDesc desc; - desc.type = SyntaxClass(nodeType); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ValNodeOperand operand; - decode(operand, decoder); - desc.operands.add(operand); - } - - desc.init(); - - auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - - // Values created during deserialization are - // not expected to ever resolve further, because - // they should be coming from fully checked code. - // - // val->resolve(); - // val->_setUnique(); - - return val; - } - - NodeBase* createEmptyShell(ASTNodeType nodeType) - { - return SyntaxClass(nodeType).createInstance(_astBuilder); - } - - SlangResult fillEmptyShells(RIFF::Chunk const* declChunk) - { - Index declIndex = 0; - - Decoder decoder(declChunk); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - auto declEmptyShell = _decls[declIndex++]; - decodeASTNodeContent(declEmptyShell, decoder); - } - - return SLANG_OK; - } - - SlangResult cleanUpNodes() - { - for (auto decl : _decls) - { - decl->checkState = DeclCheckState::CapabilityChecked; - } - - return SLANG_OK; - } - - - void assignGenericParameterIndices(GenericDecl* genericDecl) - { - int parameterCounter = 0; - for (auto m : genericDecl->members) - { - if (auto typeParam = as(m)) - { - typeParam->parameterIndex = parameterCounter++; - } - else if (auto valParam = as(m)) - { - valParam->parameterIndex = parameterCounter++; - } - } - } - - - void cleanUpASTNode(NodeBase* node) - { - if (auto expr = as(node)) - { - expr->checked = true; - } - else if (auto genericDecl = as(node)) - { - assignGenericParameterIndices(genericDecl); - } - else if (auto syntaxDecl = as(node)) - { - syntaxDecl->parseCallback = &parseSimpleSyntax; - syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); - } - else if (auto namespaceLikeDecl = as(node)) - { - auto declScope = _astBuilder->create(); - declScope->containerDecl = namespaceLikeDecl; - namespaceLikeDecl->ownedScope = declScope; - } - } - - void decodeASTNodeContent(NodeBase* node, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); + ASTNodeDispatcher::dispatch( + node, + [&](auto n) { _serializeASTNodeContents(serializer, n); }); +} - ASTNodeDispatcher::dispatch( - node, - [&](auto n) { _decodeDataOf(n, decoder); }); +enum class PseudoASTNodeType +{ + None, + ImportedModule, + ImportedDecl, +}; - cleanUpASTNode(node); - } +static PseudoASTNodeType _getPseudoASTNodeType(ASTNodeType type) +{ + return int(type) < 0 ? PseudoASTNodeType(~int(type)) : PseudoASTNodeType::None; +} - DeclID decodeDeclID(Decoder& decoder) - { - DeclID result = decoder.decode(); - return result; - } +static ASTNodeType _getAsASTNodeType(PseudoASTNodeType type) +{ + return ASTNodeType(~int(type)); +} - ValID decodeValID(Decoder& decoder) +struct ASTEncodingContext : ASTSerializerImpl +{ +public: + ASTEncodingContext( + RIFF::BuildCursor& cursor, + ModuleDecl* module, + SerialSourceLocWriter* sourceLocWriter) + : _writer(cursor.getCurrentChunk()), _module(module), _sourceLocWriter(sourceLocWriter) { - ValID result = decoder.decode(); - return result; } - template - void decodeASTNode(T*& node, Decoder& decoder) - { - ASTNodeType nodeType; - auto saved = decoder.getCursor(); - { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); - } - decoder.setCursor(saved); - - auto shell = createEmptyShell(nodeType); - decodeASTNodeContent(shell, decoder); +private: + RIFFSerialWriter _writer; + ModuleDecl* _module = nullptr; + SerialSourceLocWriter* _sourceLocWriter = nullptr; - node = as(shell); - } + virtual ISerializerImpl* getBaseSerializer() override { return &_writer; } - void decodePtr(Name*& name, Decoder& decoder, Name*) - { - String text; - decode(text, decoder); + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& node) override; + virtual void handleASTNodeContents(NodeBase* node) override; - name = _astBuilder->getNamePool()->getName(text); - } + void _writeImportedModule(ModuleDecl* moduleDecl); + void _writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl); - void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) + ModuleDecl* _findModuleForDecl(Decl* decl) { - // Mirroring the encoding logic, we decode this - // as a list of key-value pairs. - // - auto list = RefPtr(new DeclAssociationList()); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + for (auto d = decl; d; d = d->parentDecl) { - auto association = RefPtr(new DeclAssociation()); - - Decoder::WithKeyValuePair withPair(decoder); - decode(association->kind, decoder); - decode(association->decl, decoder); - - list->associations.add(association); + if (auto m = as(d)) + return m; } - - outList = list.detach(); + return nullptr; } - void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) + ModuleDecl* _findModuleDeclWasImportedFrom(Decl* decl) { - Int id; - decode(id, decoder); - info = getDiagnosticsLookup()->getDiagnosticById(id); + auto declModule = _findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; } +}; - void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) +struct ASTDecodingContext : ASTSerializerImpl +{ +public: + ASTDecodingContext( + Linkage* linkage, + ASTBuilder* astBuilder, + DiagnosticSink* sink, + RIFF::Chunk const* baseChunk, + SerialSourceLocReader* sourceLocReader, + SourceLoc requestingSourceLoc) + : _linkage(linkage) + , _astBuilder(astBuilder) + , _sink(sink) + , _sourceLocReader(sourceLocReader) + , _requestingSourceLoc(requestingSourceLoc) + , _riffReader(baseChunk) { - // TODO: is this case needed? - markupEntry = nullptr; } - void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) - { - auto result = RefPtr(new CandidateExtensionList()); - decode(result->candidateExtensions, decoder); - list = result.detach(); - } +private: + Linkage* _linkage = nullptr; + ASTBuilder* _astBuilder = nullptr; + DiagnosticSink* _sink = nullptr; + SerialSourceLocReader* _sourceLocReader = nullptr; + SourceLoc _requestingSourceLoc; + RIFFSerialReader _riffReader; - void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) - { - Decoder::WithObject withObject(decoder); - auto wt = RefPtr(new WitnessTable()); - decode(wt->baseType, decoder); - decode(wt->witnessedType, decoder); - decode(wt->isExtern, decoder); - decode(wt->m_requirementDictionary, decoder); - witnessTable = wt.detach(); - } + virtual ISerializerImpl* getBaseSerializer() override { return &_riffReader; } - void decodeValue(RequirementWitness& witness, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decodeEnum(witness.m_flavor, decoder); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& outNode) override; + virtual void handleASTNodeContents(NodeBase* node) override; - case RequirementWitness::Flavor::declRef: - decode(witness.m_declRef, decoder); - break; + ModuleDecl* _readImportedModule(); + NodeBase* _readImportedDecl(); - case RequirementWitness::Flavor::val: - decode(witness.m_val, decoder); - break; + void _cleanUpASTNode(NodeBase* node) + { + if (auto expr = as(node)) + { + expr->checked = true; + } + else if (auto decl = as(node)) + { + decl->checkState = DeclCheckState::CapabilityChecked; - case RequirementWitness::Flavor::witnessTable: + if (auto genericDecl = as(node)) + { + _assignGenericParameterIndices(genericDecl); + } + else if (auto syntaxDecl = as(node)) + { + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); + } + else if (auto namespaceLikeDecl = as(node)) { - RefPtr object; - decode(object, decoder); - witness.m_obj = object; + auto declScope = _astBuilder->create(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; } - break; } } - template - void decodePtr(T*& node, Decoder& decoder, Val*) - { - ValID id = decodeValID(decoder); - node = static_cast(getValByID(id)); - } - - template - void decodePtr(T*& node, Decoder& decoder, Decl*) + void _assignGenericParameterIndices(GenericDecl* genericDecl) { - DeclID id = decodeDeclID(decoder); - node = static_cast(getDeclByID(id)); - } - - template - void decodePtr(T*& node, Decoder& decoder, DeclBase*) - { - // This case is a bit of a hack. We need - // to identify whether we are looking at - // an indirection to a `Decl` (which would - // be serialized as an integer `DeclID`), - // or something else derived from `DeclBase`. - // - switch (decoder.getTag()) + int parameterCounter = 0; + for (auto m : genericDecl->members) { - default: - decodeASTNode(node, decoder); - break; - - case SerialBinary::kInt32FourCC: - case SerialBinary::kInt64FourCC: - case SerialBinary::kUInt32FourCC: - case SerialBinary::kUInt64FourCC: + if (auto typeParam = as(m)) + { + typeParam->parameterIndex = parameterCounter++; + } + else if (auto valParam = as(m)) { - DeclID id = decodeDeclID(decoder); - node = static_cast(getDeclByID(id)); + valParam->parameterIndex = parameterCounter++; } - break; } } +}; - template - void decodePtr(T*& node, Decoder& decoder, NodeBase*) - { - decodeASTNode(node, decoder); - } - +// +// We are matching up the corresponding `handle*()` operations from the +// `AST{Encoding|Decoding}Context` types here, so that it is easier +// to visually verify that they are serializing the same data with the +// same ordering. +// - void decodeValue(UnhandledCase, Decoder& decoder); +// +// AST{Encoding|Decoding}Context::handleName() +// - void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } +void ASTEncodingContext::handleName(Name*& value) +{ + serialize(ASTSerializer(this), value->text); +} - void decodeValue(Token& value, Decoder& decoder) - { - decode(value.type, decoder); - decode(value.flags, decoder); - decode(value.loc, decoder); - if (decoder.decodeNull()) - { - } - else - { - Name* name = nullptr; - decode(name, decoder); - value.setName(name); - } - } +void ASTDecodingContext::handleName(Name*& value) +{ + String text; + serialize(ASTSerializer(this), text); + value = _astBuilder->getNamePool()->getName(text); +} - void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } +// +// AST{Encoding|Decoding}Context::handleSourceLoc() +// - void decodeValue(SemanticVersion& value, Decoder& decoder) +void ASTEncodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (_sourceLocWriter != nullptr) { - SemanticVersion::RawValue rawValue = decoder.decode(); - value.setRawValue(rawValue); + auto rawValue = _sourceLocWriter->addSourceLoc(value); + serialize(serializer, rawValue); } +} - void decodeValue(CapabilitySet& value, Decoder& decoder) +void ASTDecodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + SerialSourceLocData::SourceLoc rawValue; + serialize(serializer, rawValue); + + if (_sourceLocReader) { - CapabilityTargetSet targetSet; - decode(targetSet, decoder); - value.getCapabilityTargetSets()[targetSet.target] = targetSet; + value = _sourceLocReader->getSourceLoc(rawValue); } } +} - void decodeValue(CapabilityTargetSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.target, decoder); +// +// AST{Encoding|Decoding}Context::handleToken() +// - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityStageSet stageSet; - decode(stageSet, decoder); - value.shaderStageSets[stageSet.stage] = stageSet; - } - } +void ASTDecodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - void decodeValue(CapabilityStageSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.stage, decoder); - decode(value.atomSet, decoder); - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); + + serialize(serializer, value.flags); - void decodeValue(CapabilityAtomSet& value, Decoder& decoder) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - CapabilityAtom atom; - decode(atom, decoder); - value.add(UInt(atom)); + String content; + serialize(serializer, content); + + // An important note here is that we cannot just + // call `value.setContent(...)` and pass in an + // `UnownedStringSlice` of `content`, because the + // `Token` will not take ownership of its own + // textual content. + // + // Instead, we need to get the text we just loaded + // into something that the `Token` can refer info, + // and the easiest way to accomplish that is to + // represent the text using a `Name`. + // + Name* name = _astBuilder->getNamePool()->getName(content); + value.setName(name); } } +} + +void ASTEncodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); + + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); + + TokenFlags flags = TokenFlags(value.flags & ~TokenFlag::Name); + serialize(serializer, flags); - template - void decodeValue(std::optional& outValue, Decoder& decoder) { - if (decoder.decodeNull()) - { - outValue.reset(); - } - else + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (value.hasContent()) { - T value; - decode(value, decoder); - outValue = value; + String content = value.getContent(); + serialize(serializer, content); } } +} - void decodeValue(SyntaxClass& syntaxClass, Decoder& decoder) - { - ASTNodeType nodeType; - decode(nodeType, decoder); - syntaxClass = SyntaxClass(nodeType); - } - - template - void decodeValue(DeclRef& declRef, Decoder& decoder) - { - decode(declRef.declRefBase, decoder); - } +// +// AST{Encoding|Decoding}Context::handleASTNode() +// - void decodeValue(ValNodeOperand& value, Decoder& decoder) +void ASTEncodingContext::handleASTNode(NodeBase*& node) +{ + if (auto decl = as(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - decodeEnum(value.kind, decoder); - switch (value.kind) + if (auto importedFromModule = _findModuleDeclWasImportedFrom(decl)) { - case ValNodeOperandKind::ConstantValue: - decode(value.values.intOperand, decoder); - break; - - case ValNodeOperandKind::ValNode: + if (decl == importedFromModule) { - Val* val = nullptr; - decode(val, decoder); - value.values.nodeOperand = val; + _writeImportedModule(importedFromModule); + return; } - break; - - case ValNodeOperandKind::ASTNode: + else { - Decl* decl = nullptr; - decode(decl, decoder); - value.values.nodeOperand = decl; + _writeImportedDecl(decl, importedFromModule); + return; } - break; } } - void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + ASTSerializer serializer(this); - void decodeValue(QualType& value, Decoder& decoder) + if (auto val = as(node)) { - Decoder::WithObject withObject(decoder); - decode(value.type, decoder); - decode(value.isLeftValue, decoder); - decode(value.hasReadOnlyOnTarget, decoder); - decode(value.isWriteOnly, decoder); - } + val = val->resolve(); - void decodeValue(MatrixCoord& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.row, decoder); - decode(value.col, decoder); + // On the reading side of things, sublcasses of `Val` + // are deduplicated as part of creation, and will read the + // operands out immediately, so we mirror that approach + // on the writing side to make sure the code is consistent. + // + serialize(serializer, val->astNodeType); + serialize(serializer, val->m_operands); } - - void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); + serialize(serializer, node->astNodeType); + deferSerializeObjectContents(serializer, node); } +} - void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.flavor, decoder); - decode(value.token, decoder); - decode(value.expr, decoder); - decode(value.bitwiseOrWith, decoder); - decode(value.knownValue, decoder); - decode(value.wrapInId, decoder); - decode(value.type, decoder); - } +void ASTDecodingContext::handleASTNode(NodeBase*& outNode) +{ + ASTSerializer serializer(this); - void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + ASTNodeType typeTag; + serialize(serializer, typeTag); + switch (_getPseudoASTNodeType(typeTag)) { - Decoder::WithObject withObject(decoder); - decode(value.opcode, decoder); - decode(value.operands, decoder); - } + default: + break; + case PseudoASTNodeType::ImportedModule: + outNode = _readImportedModule(); + return; - template - void decodeEnum(T& value, Decoder& decoder) - { - value = T(decoder.decode()); + case PseudoASTNodeType::ImportedDecl: + outNode = _readImportedDecl(); + return; } - template - void decodeSimpleValue(T& value, Decoder& decoder) + auto syntaxClass = SyntaxClass(typeTag); + if (syntaxClass.isSubClassOf()) { - value = decoder.decode(); - } + // Subclasses of `Val` are deduplicated as part + // of creation, so we need to read in their + // operands before we can create them, rather + // than allocating the object up front and + // then deserializing its content into it later. - void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } - void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + ValNodeDesc desc; + desc.type = syntaxClass; + serialize(serializer, desc.operands); - void decodeValue(uint8_t& value, Decoder& decoder) - { - value = uint8_t(decoder.decode()); - } + desc.init(); - void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) - { - decodeEnum(value, decoder); + auto node = _astBuilder->_getOrCreateImpl(std::move(desc)); + outNode = node; } - void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } - - - void decodeValue(SourceLoc& value, Decoder& decoder) + else { - if (!decoder.decodeNull()) - { - SerialSourceLocData::SourceLoc intermediate; - decoder.decode(intermediate); + auto node = syntaxClass.createInstance(_astBuilder); + outNode = node; - if (_sourceLocReader) - { - auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); - value = sourceLoc; - } - } + deferSerializeObjectContents(serializer, node); } +} - template - void decodeValue(T*& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - decodePtr(ptr, decoder, (T*)nullptr); - } +// +// AST{Encoding|Decoding}Context::handleASTNodeContents() +// - template - void decodeValue(RefPtr& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - { - // Hi Future Tess, - // - // The next step here is decoding logic for `WitnessTable`s. - // +void ASTEncodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); +} - decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); - } - } +void ASTDecodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); - void decodeValue(Modifiers& modifiers, Decoder& decoder) - { - Modifier** link = &modifiers.first; + _cleanUpASTNode(node); +} - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Modifier* modifier = nullptr; - decode(modifier, decoder); +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - *link = modifier; - link = &modifier->next; - } - } +void ASTEncodingContext::_writeImportedModule(ModuleDecl* moduleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedModule); + auto moduleName = moduleDecl->getName(); - template - void decodeValue(ShortList& array, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } - } + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, moduleName); +} +ModuleDecl* ASTDecodingContext::_readImportedModule() +{ + ASTSerializer serializer(this); - template - void decode(List& array, Decoder& decoder) + Name* moduleName = nullptr; + serialize(serializer, moduleName); + auto module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } + SLANG_ABORT_COMPILATION("failed to load an imported module during AST deserialization"); } + return module->getModuleDecl(); +} - template - void decode(T (&array)[N], Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - for (auto& element : array) - { - decode(element, decoder); - } - } +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - template - void decode(OrderedDictionary& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); +void ASTEncodingContext::_writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedDecl); + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - K key; - V value; - decode(key, decoder); - decode(value, decoder); + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); +} - dictionary.add(key, value); - } - } +NodeBase* ASTDecodingContext::_readImportedDecl() +{ + ASTSerializer serializer(this); - template - void decode(Dictionary& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); + ModuleDecl* importedFromModuleDecl = nullptr; + String mangledName; - K key; - V value; - decode(key, decoder); - decode(value, decoder); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); - dictionary.add(key, value); - } + auto importedFromModule = importedFromModuleDecl->module; + if (!importedFromModule) + { + return nullptr; } - template - void decode(T& outValue, Decoder& decoder) + auto importedDecl = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + if (!importedDecl) { - decodeValue(outValue, decoder); + SLANG_ABORT_COMPILATION( + "failed to load an imported declaration during AST deserialization"); } + return importedDecl; +} -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _decodeDataOf($T* obj, Decoder& decoder) - { -% if T.directSuperClass then - _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); -% end -% for _,f in ipairs(T.directFields) do - decode(obj->$f, decoder); -% end - } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 1 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END -}; +// +// {write|read}SerializedModuleAST() +// + +void writeSerializedModuleAST( + RIFF::BuildCursor& cursor, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) +{ + // TODO: we might want to have a more careful pass here, + // where we only encode the public declarations. + + ASTEncodingContext context(cursor, moduleDecl, sourceLocWriter); + serialize(ASTSerializer(&context), moduleDecl); +} ModuleDecl* readSerializedModuleAST( Linkage* linkage, @@ -1546,9 +919,10 @@ ModuleDecl* readSerializedModuleAST( { ASTDecodingContext context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); - context.decodeAll(); - auto node = context.getDeclByID(0); - auto moduleDecl = as(node); + + ModuleDecl* moduleDecl = nullptr; + serialize(ASTSerializer(&context), moduleDecl); return moduleDecl; } + } // namespace Slang -- cgit v1.2.3