diff options
| author | Theresa Foley <10618364+tangent-vector@users.noreply.github.com> | 2025-05-20 21:55:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-21 04:55:39 +0000 |
| commit | 9059093bc764e901a9c4aaeb12471bf32028874f (patch) | |
| tree | 7058871ce0ec4397b6e8996506357e41ebb2517d /source/slang/slang-serialize-ast.cpp | |
| parent | 52d70f37f66d8fc34bc142386490bdcde0fc7db0 (diff) | |
Generalize serialization system used for AST (#7126)
This change takes the new approach to serialization that was used for the AST and generalizes it in a few ways:
* The new approach is no longer tangled up with the RIFF format.
The serialization system supports multiple different implementations of the underlying format.
The existing RIFF format is now supported as one back-end, but support for others will follow in subsequent changes.
* The new approach is no longer deeply specialized to AST serialization.
The old code had things like serialization for `List`s and `Dictionary`s, but it was embedded inside the `AST{Encoding|Decoding}Context`, and thus couldn't be leveraged for other serialization tasks.
This change factors out a completely AST-independent `Serializer` implementation, with an `ASTSerializer` layered on top of it to provide the additional context needed.
* There is less duplication of code between reading and writing of serialized data.
The old code had both the `ASTEncodingContext` and `ASTDecodingContext`, with serialization logic for most types being implemented in both, but with the constraint that those implementations needed to be kept in sync to avoid serialization-related runtime failures.
A key property of the revamped approach is that a single `serialize()` method for a type implements both the reading and writing directions of serialization.
Diffstat (limited to 'source/slang/slang-serialize-ast.cpp')
| -rw-r--r-- | source/slang/slang-serialize-ast.cpp | 1932 |
1 files changed, 653 insertions, 1279 deletions
diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 1ef532ad1..9a61dbc5a 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -5,6 +5,7 @@ #include "slang-compiler.h" #include "slang-diagnostics.h" #include "slang-mangle.h" +#include "slang-serialize-riff.h" namespace Slang { @@ -13,654 +14,512 @@ namespace Slang // NodeBase* parseSimpleSyntax(Parser* parser, void* userData); +// +// Many of the types used in the AST can be serialized using +// just the `Serializer` type, so we will handle all of those first. +// -struct ASTEncodingContext +void serialize(Serializer const& serializer, ASTNodeType& value) { -private: - Encoder* encoder; - struct UnhandledCase - { - }; - - typedef Int DeclID; - Dictionary<Decl*, DeclID> mapDeclToID; - List<Decl*> decls; - - struct ImportedDeclInfo - { - Int moduleIndex = -1; - Decl* decl; - }; - List<ImportedDeclInfo> importedDecls; - - typedef Int ValID; - Dictionary<Val*, ValID> mapValToID; - List<Val*> vals; - - ModuleDecl* _module = nullptr; - - SerialSourceLocWriter* _sourceLocWriter = nullptr; + serializeEnum(serializer, value); +} -public: - ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) - : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) - { - } +void serialize(Serializer const& serializer, TypeTag& value) +{ + serializeEnum(serializer, value); +} - template<typename T> - void encodeASTNodeContent(T* node) - { - Encoder::WithObject withObject(encoder); +void serialize(Serializer const& serializer, BaseType& value) +{ + serializeEnum(serializer, value); +} - ASTNodeDispatcher<T, void>::dispatch(node, [&](auto n) { _encodeDataOf(n); }); - } +void serialize(Serializer const& serializer, TryClauseType& value) +{ + serializeEnum(serializer, value); +} - void flush() - { - auto containerChunk = encoder->getRIFFChunk(); +void serialize(Serializer const& serializer, DeclVisibility& value) +{ + serializeEnum(serializer, value); +} - RIFF::ChunkBuilder* declChunk = nullptr; - RIFF::ChunkBuilder* importedDeclChunk = nullptr; - RIFF::ChunkBuilder* valChunk = nullptr; - { - Encoder::WithArray withList(encoder); - declChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - importedDeclChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - valChunk = encoder->getRIFFChunk(); - } - Int declIndex = 0; - Int importedDeclIndex = 0; - Int valIndex = 0; +void serialize(Serializer const& serializer, BuiltinRequirementKind& value) +{ + serializeEnum(serializer, value); +} - bool done = false; - do - { - done = true; - while (declIndex < decls.getCount()) - { - done = false; - encoder->setRIFFChunk(declChunk); - encodeASTNodeContent(decls[declIndex++]); - } - while (importedDeclIndex < importedDecls.getCount()) - { - done = false; - encoder->setRIFFChunk(importedDeclChunk); - encodeImportedDecl(importedDecls[importedDeclIndex++]); - } - while (valIndex < vals.getCount()) - { - done = false; - encoder->setRIFFChunk(valChunk); - encodeASTNodeContent(vals[valIndex++]); - } - } while (!done); +void serialize(Serializer const& serializer, ImageFormat& value) +{ + serializeEnum(serializer, value); +} - encoder->setRIFFChunk(containerChunk); - } +void serialize(Serializer const& serializer, PreferRecomputeAttribute::SideEffectBehavior& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleForDecl(Decl* decl) - { - for (auto d = decl; d; d = d->parentDecl) - { - if (auto m = as<ModuleDecl>(d)) - return m; - } - return nullptr; - } +void serialize(Serializer const& serializer, TreatAsDifferentiableExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) - { - auto declModule = findModuleForDecl(decl); - if (declModule == nullptr) - return nullptr; - if (declModule == _module) - return nullptr; - return declModule; - } +void serialize(Serializer const& serializer, LogicOperatorShortCircuitExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - DeclID getDeclID(Decl* decl) - { - SLANG_ASSERT(decl != nullptr); +void serialize(Serializer const& serializer, RequirementWitness::Flavor& value) +{ + serializeEnum(serializer, value); +} - if (auto found = mapDeclToID.tryGetValue(decl)) - return *found; +void serialize(Serializer const& serializer, CapabilityAtom& value) +{ + serializeEnum(serializer, value); +} - // We need to detect whether the declaration is an - // imported one, or one from this module itself. - // - // Imported declarations need to be handled very - // differently, since they'll involve resolving - // references to those other modules, and the - // declarations within them. - // - if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) - { - DeclID importedFromModuleDeclID = 0; - if (decl != importedFromModule) - { - importedFromModuleDeclID = getDeclID(importedFromModule); - } +void serialize(Serializer const& serializer, DeclAssociationKind& value) +{ + serializeEnum(serializer, value); +} - DeclID id = ~importedDecls.getCount(); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, TokenType& value) +{ + serializeEnum(serializer, value); +} - ImportedDeclInfo info; - info.moduleIndex = ~importedFromModuleDeclID; - info.decl = decl; - importedDecls.add(info); +void serialize(Serializer const& serializer, ValNodeOperandKind& value) +{ + serializeEnum(serializer, value); +} - return id; - } - else - { - DeclID id = decls.getCount(); - decls.add(decl); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, SPIRVAsmOperand::Flavor& value) +{ + serializeEnum(serializer, value); +} - return id; - } - } +void serialize(Serializer const& serializer, MatrixCoord& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.row); + serialize(serializer, value.col); +} - void encodePtr(Decl* decl) +void serializePtr(Serializer const& serializer, DiagnosticInfo const*& value, DiagnosticInfo const*) +{ + Int32 id = 0; + if (isWriting(serializer)) { - DeclID id = getDeclID(decl); - encoder->encode(id); + id = value->id; + serialize(serializer, id); } - - ValID getValID(Val* val) + else { - SLANG_ASSERT(val != nullptr); - - if (auto found = mapValToID.tryGetValue(val)) - return *found; - - // In order to ensure that values can be fully constructed - // from the get-go (so that they will get cached correctly), - // we conspire to ensure that every value is preceded by - // all of its operands. - // - for (auto operand : val->m_operands) - { - switch (operand.kind) - { - default: - break; - - case ValNodeOperandKind::ValNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Val>(operandNode)); - getValID(static_cast<Val*>(operandNode)); - } - break; - - case ValNodeOperandKind::ASTNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Decl>(operandNode)); - getDeclID(static_cast<Decl*>(operandNode)); - } - break; - } - } - auto resolved = val->resolve(); - if (resolved != val) - { - getValID(resolved); - } - - ValID id = vals.getCount(); - vals.add(val); - mapValToID.add(val, id); - return id; + serialize(serializer, id); + value = getDiagnosticsLookup()->getDiagnosticById(id); } +} - void encodePtr(Val* val) +void serialize(Serializer const& serializer, SemanticVersion& value) +{ + auto raw = value.getRawValue(); + serialize(serializer, raw); + value = SemanticVersion::fromRaw(raw); +} + +void serialize(Serializer const& serializer, SyntaxClass<NodeBase>& value) +{ + ASTNodeType raw; + if (isWriting(serializer)) { - ValID id = getValID(val); - encoder->encode(id); + raw = value.getTag(); } - - void encodeImportedDecl(ImportedDeclInfo const& info) + serialize(serializer, raw); + if (isReading(serializer)) { - Encoder::WithKeyValuePair withPair(encoder); - encode(info.moduleIndex); - auto decl = info.decl; - if (auto importedModuleDecl = as<ModuleDecl>(decl)) - { - SLANG_ASSERT(info.moduleIndex == -1); - encode(importedModuleDecl->getName()); - } - else - { - auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - encode(mangledName); - } + value = SyntaxClass<NodeBase>(raw); } +} - void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } - void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } - void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } +// +// Many types in the AST need additional context (beyond +// what the `Serializer` has) in order to serialize +// themselves or their members. +// +// We define a custom serializer interface to capture +// the cases that can't be handled by a `Serializer` +// alone. +// - void encodePtr(Name* name) { encode(name->text); } +/// Interface for AST serialization +struct ASTSerializerImpl +{ +public: + virtual void handleASTNode(NodeBase*& value) = 0; + virtual void handleASTNodeContents(NodeBase* value) = 0; + virtual void handleName(Name*& value) = 0; + virtual void handleSourceLoc(SourceLoc& value) = 0; + virtual void handleToken(Token& value) = 0; - void encodePtr(MarkupEntry* entry) - { - // TODO: is this case needed? - SLANG_UNUSED(entry); - } + // Note that this type does *not* inherit from `ISerializerImpl`. + // + // We want to decouple the AST-specific context information + // from the lower-level details of the serialization format. + // + // Instead of using inheritance, we expect that any + // `ASTSerializerImpl` will aggregate a lower-level + // serializer, and the interface exposes access to + // that base serializer implementation. - void encodePtr(DeclAssociationList* list) - { - // We serialize this as if it were a simple list - // of key-value pairs because... well... that's - // what it amounts to in practice. - // - Encoder::WithArray withArray(encoder); - for (auto association : list->associations) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(association->kind); - encode(association->decl); - } - } + virtual ISerializerImpl* getBaseSerializer() = 0; +}; - void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } +/// Specialization of `Serializer_` for AST serialization. +template<> +struct Serializer_<ASTSerializerImpl> : SerializerBase<ASTSerializerImpl> +{ +public: + using SerializerBase::SerializerBase; - void encodePtr(WitnessTable* witnessTable) - { - Encoder::WithObject withObject(encoder); - encode(witnessTable->baseType); - encode(witnessTable->witnessedType); - encode(witnessTable->isExtern); - - // TODO(tfoley): In theory we should be able to streamline - // this so that we only encode the requirements that we - // absolutely need to (which basically amounts to `associatedtype` - // requirements where the satisfying type is part of the public - // API of the type). - // - encode(witnessTable->m_requirementDictionary); - } + // + // In order to allow an `ASTSerializer` to be used with + // functions that expect an ordinary `Serializer`, we + // implement an implicit conversion operator. + // - void encodeValue(RequirementWitness const& witness) - { - Encoder::WithKeyValuePair withPair(encoder); - encodeEnum(witness.m_flavor); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; + operator Serializer() const { return Serializer(get()->getBaseSerializer()); } +}; - case RequirementWitness::Flavor::declRef: - encode(witness.m_declRef); - break; +/// Context type for AST serialization. +using ASTSerializer = Serializer_<ASTSerializerImpl>; - case RequirementWitness::Flavor::val: - encode(witness.m_val); - break; +template<typename T> +void serializeObject(ASTSerializer const& serializer, T*& value, NodeBase*) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serializer->handleASTNode(*(NodeBase**)&value); +} - case RequirementWitness::Flavor::witnessTable: - encode((WitnessTable*)witness.m_obj.Ptr()); - break; - } - } +void serializeObjectContents(ASTSerializer const& serializer, NodeBase* value, NodeBase*) +{ + serializer->handleASTNodeContents(value); +} - void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } +template<typename T> +void serialize(ASTSerializer const& serializer, DeclRef<T>& value) +{ + serialize(serializer, value.declRefBase); +} - void encodePtr(DeclBase* declBase) +void serialize(ASTSerializer const& serializer, SourceLoc& value) +{ + serializer->handleSourceLoc(value); +} + +void serialize(ASTSerializer const& serializer, RequirementWitness& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.m_flavor); + switch (value.m_flavor) { - if (auto decl = as<Decl>(declBase)) - { - encodePtr(decl); - } - else - { - encodeASTNodeContent(declBase); - } - } + case RequirementWitness::Flavor::none: + break; - void encodeValue(UnhandledCase); + case RequirementWitness::Flavor::declRef: + serialize(serializer, value.m_declRef); + break; - void encodeValue(String const& value) { encoder->encode(value); } + case RequirementWitness::Flavor::val: + serialize(serializer, value.m_val); + break; - void encodeValue(Token const& value) - { - encode(value.type); - encode(TokenFlags(value.flags & ~TokenFlag::Name)); - encode(value.loc); - if (value.hasContent()) - encoder->encodeString(value.getContent()); - else - encode(nullptr); + case RequirementWitness::Flavor::witnessTable: + serialize(serializer, value.m_obj); + break; } +} - void encodeValue(NameLoc const& value) { encode(value.name); } - - void encodeValue(SemanticVersion value) { encoder->encode(value.getRawValue()); } +void serialize(ASTSerializer const& serializer, WitnessTable& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.baseType); + serialize(serializer, value.witnessedType); + serialize(serializer, value.isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + serialize(serializer, value.m_requirementDictionary); +} - void encodeValue(CapabilitySet const& value) +void serialize(Serializer const& serializer, CapabilityAtomSet& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - // While the `CapabilityTargetSets` type is a dictionary, - // in practice each entry already embeds its own key - // (the target atom), so we can encode this as just - // an array of the `CapabilityTargetSet` values. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.getCapabilityTargetSets()) + for (auto rawAtom : value) { - encode(pair.second); + auto atom = CapabilityAtom(rawAtom); + serialize(serializer, atom); } } - - void encodeValue(CapabilityTargetSet const& value) + else { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.target); - - // Similar to the case for the `CapabilityTargetSets` above, - // each `CapabilityStageSet` already includes the stage atom, - // so we can simply encode the values from the dictionary. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.shaderStageSets) + while (hasElements(serializer)) { - encode(pair.second); + CapabilityAtom atom; + serialize(serializer, atom); + value.add(UInt(atom)); } } +} - void encodeValue(CapabilityStageSet const& value) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.stage); - encode(value.atomSet); - } +void serialize(Serializer const& serializer, CapabilityStageSet& value) +{ + serialize(serializer, value.atomSet); +} - void encodeValue(CapabilityAtomSet const& value) +void serialize(Serializer const& serializer, CapabilityTargetSet& value) +{ + serialize(serializer, value.shaderStageSets); + + // The value for each entry in `shaderStageSets` have + // a `stage` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityStageSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - Encoder::WithArray withArray(encoder); - for (auto rawAtom : value) - { - encode(CapabilityAtom(rawAtom)); - } + for (auto& p : value.shaderStageSets) + p.second.stage = p.first; } +} - template<typename T> - void encodeValue(std::optional<T> const& value) +void serialize(Serializer const& serializer, CapabilitySet& value) +{ + serialize(serializer, value.getCapabilityTargetSets()); + + // The value for each entry in `getCapabilityTargetSets()` have + // a `target` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityTargetSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - if (value) - encodeValue(*value); - else - encoder->encode(nullptr); + for (auto& p : value.getCapabilityTargetSets()) + p.second.target = p.first; } +} + +void serialize(ASTSerializer const& serializer, CandidateExtensionList& value) +{ + serialize(serializer, value.candidateExtensions); +} + +void serialize(ASTSerializer const& serializer, DeclAssociation& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.kind); + serialize(serializer, value.decl); +} - void encodeValue(SyntaxClass<NodeBase> const& value) { encode(value.getTag()); } +void serialize(ASTSerializer const& serializer, DeclAssociationList& value) +{ + serialize(serializer, value.associations); +} - template<typename T> - void encodeValue(DeclRef<T> const& value) +void serialize(ASTSerializer const& serializer, Modifiers& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - encode((DeclRefBase*)value); + for (auto modifier : value) + { + serialize(serializer, modifier); + } } - - void encodeValue(ValNodeOperand value) + else { - Encoder::WithKeyValuePair withPair(encoder); + Modifier** link = &value.first; - encodeEnum(value.kind); - switch (value.kind) + while (hasElements(serializer)) { - case ValNodeOperandKind::ConstantValue: - encode(value.values.intOperand); - break; - - case ValNodeOperandKind::ValNode: - encode(static_cast<Val*>(value.values.nodeOperand)); - break; + Modifier* modifier = nullptr; + serialize(serializer, modifier); - case ValNodeOperandKind::ASTNode: - { - if (auto decl = as<Decl>(value.values.nodeOperand)) - { - encode(decl); - } - else - { - SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); - } - } - break; + *link = modifier; + link = &modifier->next; } } +} - void encodeValue(TypeExp value) { encode(value.type); } +void serialize(ASTSerializer const& serializer, TypeExp& value) +{ + serialize(serializer, value.type); +} - void encodeValue(QualType value) - { - Encoder::WithObject withObject(encoder); - encode(value.type); - encode(value.isLeftValue); - encode(value.hasReadOnlyOnTarget); - encode(value.isWriteOnly); - } +void serialize(ASTSerializer const& serializer, QualType& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.isLeftValue); + serialize(serializer, value.hasReadOnlyOnTarget); + serialize(serializer, value.isWriteOnly); +} - void encodeValue(MatrixCoord value) - { - Encoder::WithObject withObject(encoder); - encode(value.row); - encode(value.col); - } +void serialize(ASTSerializer const& serializer, Token& value) +{ + serializer->handleToken(value); +} - void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } +void serialize(ASTSerializer const& serializer, SPIRVAsmOperand& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.flavor); + serialize(serializer, value.token); + serialize(serializer, value.expr); + serialize(serializer, value.bitwiseOrWith); + serialize(serializer, value.knownValue); + serialize(serializer, value.wrapInId); + serialize(serializer, value.type); +} - void encodeValue(SPIRVAsmOperand const& value) - { - Encoder::WithObject withObject(encoder); - encode(value.flavor); - encode(value.token); - encode(value.expr); - encode(value.bitwiseOrWith); - encode(value.knownValue); - encode(value.wrapInId); - encode(value.type); - } +void serialize(ASTSerializer const& serializer, SPIRVAsmInst& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.opcode); + serialize(serializer, value.operands); +} - void encodeValue(SPIRVAsmInst const& value) +void serialize(ASTSerializer const& serializer, ValNodeOperand& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.kind); + switch (value.kind) { - Encoder::WithObject withObject(encoder); - encode(value.opcode); - encode(value.operands); - } + case ValNodeOperandKind::ConstantValue: + serialize(serializer, value.values.intOperand); + break; - - template<typename T, typename = std::enable_if_t<std::is_same_v<T, bool>>> - void encodeValue(T value) - { - encoder->encodeBool(value); + case ValNodeOperandKind::ValNode: + case ValNodeOperandKind::ASTNode: + serialize(serializer, value.values.nodeOperand); + break; } +} - void encodeValue(Int32 value) { encoder->encode(value); } - void encodeValue(UInt32 value) { encoder->encode(value); } - void encodeValue(Int64 value) { encoder->encode(value); } - void encodeValue(UInt64 value) { encoder->encode(value); } - void encodeValue(float value) { encoder->encode(value); } - void encodeValue(double value) { encoder->encode(value); } - - void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } - - void encodeValue(std::nullptr_t) { encoder->encode(nullptr); } +void serializeObject(ASTSerializer const& serializer, Name*& value, Name*) +{ + serializer->handleName(value); +} - template<typename T> - void encodeEnum(T value) - { - encoder->encode(Int32(value)); - } +void serialize(ASTSerializer const& serializer, NameLoc& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.name); + serialize(serializer, value.loc); +} - void encodeValue(DeclVisibility value) { encodeEnum(value); } - void encodeValue(BaseType value) { encodeEnum(value); } - void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } - void encodeValue(ASTNodeType value) { encodeEnum(value); } - void encodeValue(ImageFormat value) { encodeEnum(value); } - void encodeValue(TypeTag value) { encodeEnum(value); } - void encodeValue(TryClauseType value) { encodeEnum(value); } - void encodeValue(CapabilityAtom value) { encodeEnum(value); } - void encodeValue(DeclAssociationKind value) { encodeEnum(value); } - void encodeValue(TokenType value) { encodeEnum(value); } - - void encodeValue(SourceLoc value) - { - if (!_sourceLocWriter) - { - encoder->encode(nullptr); - } - else - { - auto intermediate = _sourceLocWriter->addSourceLoc(value); - encoder->encode(intermediate); +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do +void _serializeASTNodeContents(ASTSerializer const& serializer, $T* value) +{ + SLANG_UNUSED(serializer); + SLANG_UNUSED(value); +% if T.directSuperClass then + _serializeASTNodeContents(serializer, static_cast<$(T.directSuperClass)*>(value)); +% end +% for _,f in ipairs(T.directFields) do + serialize(serializer, value->$f); +% end } - } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END - template<typename T> - void encodeValue(T const* ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(const_cast<T*>(ptr)); - } - } +void serializeASTNodeContents(ASTSerializer const& serializer, NodeBase* node) +{ + ASTNodeDispatcher<NodeBase, void>::dispatch( + node, + [&](auto n) { _serializeASTNodeContents(serializer, n); }); +} - template<typename T> - void encodeValue(RefPtr<T> const& ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(ptr.Ptr()); - } - } +enum class PseudoASTNodeType +{ + None, + ImportedModule, + ImportedDecl, +}; - void encodeValue(Modifiers const& modifiers) - { - Encoder::WithArray withArray(encoder); - for (auto m : const_cast<Modifiers&>(modifiers)) - { - encode(m); - } - } +static PseudoASTNodeType _getPseudoASTNodeType(ASTNodeType type) +{ + return int(type) < 0 ? PseudoASTNodeType(~int(type)) : PseudoASTNodeType::None; +} + +static ASTNodeType _getAsASTNodeType(PseudoASTNodeType type) +{ + return ASTNodeType(~int(type)); +} - template<typename T, int N> - void encodeValue(ShortList<T, N> const& array) +struct ASTEncodingContext : ASTSerializerImpl +{ +public: + ASTEncodingContext( + RIFF::BuildCursor& cursor, + ModuleDecl* module, + SerialSourceLocWriter* sourceLocWriter) + : _writer(cursor.getCurrentChunk()), _module(module), _sourceLocWriter(sourceLocWriter) { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } } +private: + RIFFSerialWriter _writer; + ModuleDecl* _module = nullptr; + SerialSourceLocWriter* _sourceLocWriter = nullptr; - template<typename T> - void encode(List<T> const& array) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual ISerializerImpl* getBaseSerializer() override { return &_writer; } - template<typename T, size_t N> - void encode(T const (&array)[N]) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& node) override; + virtual void handleASTNodeContents(NodeBase* node) override; - template<typename K, typename V> - void encode(OrderedDictionary<K, V> const& dictionary) - { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.key); - encode(p.value); - } - } + void _writeImportedModule(ModuleDecl* moduleDecl); + void _writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl); - template<typename K, typename V> - void encode(Dictionary<K, V> const& dictionary) + ModuleDecl* _findModuleForDecl(Decl* decl) { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) + for (auto d = decl; d; d = d->parentDecl) { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.first); - encode(p.second); + if (auto m = as<ModuleDecl>(d)) + return m; } + return nullptr; } - template<typename T> - void encode(T const& value) + ModuleDecl* _findModuleDeclWasImportedFrom(Decl* decl) { - encodeValue(value); - } - - // for each class of node, we generate - // code to recursively serialize each - // of its fields. - -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _encodeDataOf($T* obj) - { -%if T.directSuperClass then - _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); -%end -%for _,f in ipairs(T.directFields) do - encode(obj->$f); -%end + auto declModule = _findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 0 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END }; -void writeSerializedModuleAST( - Encoder* encoder, - ModuleDecl* moduleDecl, - SerialSourceLocWriter* sourceLocWriter) -{ - Encoder::WithObject withObject(encoder); - - // TODO: we should have a more careful pass here, - // where we only encode the public declarations - // - - ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); - context.getDeclID(moduleDecl); - context.flush(); -} - -struct ASTDecodingContext +struct ASTDecodingContext : ASTSerializerImpl { public: ASTDecodingContext( @@ -673,284 +532,60 @@ public: : _linkage(linkage) , _astBuilder(astBuilder) , _sink(sink) - , _baseChunk(as<RIFF::ListChunk>(baseChunk)) , _sourceLocReader(sourceLocReader) , _requestingSourceLoc(requestingSourceLoc) + , _riffReader(baseChunk) { } +private: Linkage* _linkage = nullptr; + ASTBuilder* _astBuilder = nullptr; DiagnosticSink* _sink = nullptr; SerialSourceLocReader* _sourceLocReader = nullptr; SourceLoc _requestingSourceLoc; + RIFFSerialReader _riffReader; - SlangResult decodeAll() - { - auto cursor = _baseChunk->getChildren().begin(); - - // There are a few different top-level chunks that - // hold different arrays that we need in order - // to decode the entire module hierarchy. - // - // Basically, these lists correspond to the kinds - // of nodes in the AST hierarchy for which back-references - // are allowed (all other nodes should, barring - // weird corner cases, form a single tree-structured - // ownership hierarchy, rooted at the `ModuleDecl`. - // - - // First there is the list that actually encodes - // for the declarations in the module, including - // the `ModuleDecl` itself, which should be the - // first entry in the list. - // - auto declChunk = *cursor; - ++cursor; - - // Next there is a list of all the declarations - // referenced inside of the module that need to - // be imported in from outside. - // - auto importedDeclChunk = *cursor; - ++cursor; + virtual ISerializerImpl* getBaseSerializer() override { return &_riffReader; } - // Then there are all the `Val`-derived nodes that - // are needed by the module, which will need to be - // deduplicated so that they are unique within the - // current compilation context. - // - auto valChunk = *cursor; - ++cursor; + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& outNode) override; + virtual void handleASTNodeContents(NodeBase* node) override; - // The process of decoding the module is then spread - // over a number of steps. - // - // The first step is to process all of the imported - // declarations, so that other nodes can refer to - // them. - // - SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); - - // Next we process the declarations that are within - // the module itself, first creating an "empty shell" - // of each declaration that has the right size in - // memory (and the right `ASTNodeType` tag), so that - // we can wire up references to it (including circular - // references)... so long as nothing here tries to - // look *inside* the empty shell along the way. - // - SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); - - // Once all the `Decl`s that might be needed have - // been allocated, we can process all the `Val`s - // that might reference those`Decl`s (and one another). - // - // The nature of the `Val` representation ensures - // that there cannot be cirularities in the references - // between `Val`s, and the encoding process will have - // sorted the entries so that a `Val` only ever appears - // *after* its operands. - // - SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + ModuleDecl* _readImportedModule(); + NodeBase* _readImportedDecl(); - // Once all the back-reference-able objects have been - // instantiated in memory, we can go back through the - // `Decl`s in the module and fill in those empty shells. - // - SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); - - // As a final pass, we perform any special cleanup actions - // that might be required to make the output valid for consumers. - // - // For example, this is where we set the `DeclCheckState` of everything - // we are loading to reflect the fact that everything we deserialize - // is (supposed to be) fully cheked. - // - SLANG_RETURN_ON_FAIL(cleanUpNodes()); - - - return SLANG_OK; - } - - typedef Int DeclID; - Decl* getDeclByID(DeclID id) + void _cleanUpASTNode(NodeBase* node) { - if (id >= 0) - { - return _decls[id]; - } - else + if (auto expr = as<Expr>(node)) { - return _importedDecls[~id]; + expr->checked = true; } - } - -private: - struct UnhandledCase - { - }; - - ASTBuilder* _astBuilder = nullptr; - RIFF::ListChunk const* _baseChunk = nullptr; - - List<Decl*> _decls; - List<Decl*> _importedDecls; - List<Val*> _vals; - - typedef Int ValID; - Val* getValByID(ValID id) { return _vals[id]; } - - SlangResult decodeImportedDecls(RIFF::Chunk const* importedDeclChunk) - { - Decoder decoder(importedDeclChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + else if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - Int moduleIndex; - decode(moduleIndex, decoder); + decl->checkState = DeclCheckState::CapabilityChecked; - if (moduleIndex == -1) + if (auto genericDecl = as<GenericDecl>(node)) { - Name* moduleName = nullptr; - decode(moduleName, decoder); - - Decl* importedModule = getImportedModule(moduleName); - _importedDecls.add(importedModule); + _assignGenericParameterIndices(genericDecl); } - else + else if (auto syntaxDecl = as<SyntaxDecl>(node)) { - auto importedFromModuleDecl = as<ModuleDecl>(_importedDecls[moduleIndex]); - auto importedFromModule = importedFromModuleDecl->module; - - String mangledName; - decode(mangledName, decoder); - - auto importedNode = - importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); - auto importedDecl = as<Decl>(importedNode); - _importedDecls.add(importedDecl); + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); } - } - return SLANG_OK; - } - - ModuleDecl* getImportedModule(Name* moduleName) - { - Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); - if (!module) - { - SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); - } - - return module->getModuleDecl(); - } - - SlangResult decodeVals(RIFF::Chunk const* valChunk) - { - Decoder decoder(valChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Val* val = decodeValNode(decoder); - _vals.add(val); - } - return SLANG_OK; - } - - SlangResult createEmptyShells(RIFF::Chunk const* declChunk) - { - Decoder decoder(declChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ASTNodeType nodeType; - - // Each of the declarations is expected to take - // the form of an object with a first field - // that holds the node type. - // + else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); + auto declScope = _astBuilder->create<Scope>(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; } - - auto emptyShell = createEmptyShell(nodeType); - auto declEmptyShell = as<Decl>(emptyShell); - _decls.add(declEmptyShell); } - - return SLANG_OK; } - Val* decodeValNode(Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeType nodeType; - decode(nodeType, decoder); - - ValNodeDesc desc; - desc.type = SyntaxClass<NodeBase>(nodeType); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ValNodeOperand operand; - decode(operand, decoder); - desc.operands.add(operand); - } - - desc.init(); - - auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - - // Values created during deserialization are - // not expected to ever resolve further, because - // they should be coming from fully checked code. - // - // val->resolve(); - // val->_setUnique(); - - return val; - } - - NodeBase* createEmptyShell(ASTNodeType nodeType) - { - return SyntaxClass<NodeBase>(nodeType).createInstance(_astBuilder); - } - - SlangResult fillEmptyShells(RIFF::Chunk const* declChunk) - { - Index declIndex = 0; - - Decoder decoder(declChunk); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - auto declEmptyShell = _decls[declIndex++]; - decodeASTNodeContent(declEmptyShell, decoder); - } - - return SLANG_OK; - } - - SlangResult cleanUpNodes() - { - for (auto decl : _decls) - { - decl->checkState = DeclCheckState::CapabilityChecked; - } - - return SLANG_OK; - } - - - void assignGenericParameterIndices(GenericDecl* genericDecl) + void _assignGenericParameterIndices(GenericDecl* genericDecl) { int parameterCounter = 0; for (auto m : genericDecl->members) @@ -965,576 +600,314 @@ private: } } } +}; +// +// We are matching up the corresponding `handle*()` operations from the +// `AST{Encoding|Decoding}Context` types here, so that it is easier +// to visually verify that they are serializing the same data with the +// same ordering. +// - void cleanUpASTNode(NodeBase* node) - { - if (auto expr = as<Expr>(node)) - { - expr->checked = true; - } - else if (auto genericDecl = as<GenericDecl>(node)) - { - assignGenericParameterIndices(genericDecl); - } - else if (auto syntaxDecl = as<SyntaxDecl>(node)) - { - syntaxDecl->parseCallback = &parseSimpleSyntax; - syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); - } - else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) - { - auto declScope = _astBuilder->create<Scope>(); - declScope->containerDecl = namespaceLikeDecl; - namespaceLikeDecl->ownedScope = declScope; - } - } - - void decodeASTNodeContent(NodeBase* node, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeDispatcher<NodeBase, void>::dispatch( - node, - [&](auto n) { _decodeDataOf(n, decoder); }); +// +// AST{Encoding|Decoding}Context::handleName() +// - cleanUpASTNode(node); - } +void ASTEncodingContext::handleName(Name*& value) +{ + serialize(ASTSerializer(this), value->text); +} - DeclID decodeDeclID(Decoder& decoder) - { - DeclID result = decoder.decode<DeclID>(); - return result; - } +void ASTDecodingContext::handleName(Name*& value) +{ + String text; + serialize(ASTSerializer(this), text); + value = _astBuilder->getNamePool()->getName(text); +} - ValID decodeValID(Decoder& decoder) - { - ValID result = decoder.decode<ValID>(); - return result; - } +// +// AST{Encoding|Decoding}Context::handleSourceLoc() +// - template<typename T> - void decodeASTNode(T*& node, Decoder& decoder) +void ASTEncodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (_sourceLocWriter != nullptr) { - ASTNodeType nodeType; - auto saved = decoder.getCursor(); - { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); - } - decoder.setCursor(saved); - - auto shell = createEmptyShell(nodeType); - decodeASTNodeContent(shell, decoder); - - node = as<T>(shell); + auto rawValue = _sourceLocWriter->addSourceLoc(value); + serialize(serializer, rawValue); } +} - void decodePtr(Name*& name, Decoder& decoder, Name*) +void ASTDecodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - String text; - decode(text, decoder); - - name = _astBuilder->getNamePool()->getName(text); - } + SerialSourceLocData::SourceLoc rawValue; + serialize(serializer, rawValue); - void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) - { - // Mirroring the encoding logic, we decode this - // as a list of key-value pairs. - // - auto list = RefPtr(new DeclAssociationList()); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + if (_sourceLocReader) { - auto association = RefPtr(new DeclAssociation()); - - Decoder::WithKeyValuePair withPair(decoder); - decode(association->kind, decoder); - decode(association->decl, decoder); - - list->associations.add(association); + value = _sourceLocReader->getSourceLoc(rawValue); } - - outList = list.detach(); } +} - void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) - { - Int id; - decode(id, decoder); - info = getDiagnosticsLookup()->getDiagnosticById(id); - } - - void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) - { - // TODO: is this case needed? - markupEntry = nullptr; - } - - void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) - { - auto result = RefPtr(new CandidateExtensionList()); - decode(result->candidateExtensions, decoder); - list = result.detach(); - } - - void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) - { - Decoder::WithObject withObject(decoder); - auto wt = RefPtr(new WitnessTable()); - decode(wt->baseType, decoder); - decode(wt->witnessedType, decoder); - decode(wt->isExtern, decoder); - decode(wt->m_requirementDictionary, decoder); - witnessTable = wt.detach(); - } - - void decodeValue(RequirementWitness& witness, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decodeEnum(witness.m_flavor, decoder); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; - - case RequirementWitness::Flavor::declRef: - decode(witness.m_declRef, decoder); - break; - - case RequirementWitness::Flavor::val: - decode(witness.m_val, decoder); - break; +// +// AST{Encoding|Decoding}Context::handleToken() +// - case RequirementWitness::Flavor::witnessTable: - { - RefPtr<WitnessTable> object; - decode(object, decoder); - witness.m_obj = object; - } - break; - } - } +void ASTDecodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Val*) - { - ValID id = decodeValID(decoder); - node = static_cast<T*>(getValByID(id)); - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Decl*) - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } + serialize(serializer, value.flags); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, DeclBase*) { - // This case is a bit of a hack. We need - // to identify whether we are looking at - // an indirection to a `Decl` (which would - // be serialized as an integer `DeclID`), - // or something else derived from `DeclBase`. - // - switch (decoder.getTag()) + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - default: - decodeASTNode(node, decoder); - break; - - case SerialBinary::kInt32FourCC: - case SerialBinary::kInt64FourCC: - case SerialBinary::kUInt32FourCC: - case SerialBinary::kUInt64FourCC: - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } - break; - } - } - - template<typename T> - void decodePtr(T*& node, Decoder& decoder, NodeBase*) - { - decodeASTNode(node, decoder); - } + String content; + serialize(serializer, content); - - void decodeValue(UnhandledCase, Decoder& decoder); - - void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } - - void decodeValue(Token& value, Decoder& decoder) - { - decode(value.type, decoder); - decode(value.flags, decoder); - decode(value.loc, decoder); - if (decoder.decodeNull()) - { - } - else - { - Name* name = nullptr; - decode(name, decoder); + // An important note here is that we cannot just + // call `value.setContent(...)` and pass in an + // `UnownedStringSlice` of `content`, because the + // `Token` will not take ownership of its own + // textual content. + // + // Instead, we need to get the text we just loaded + // into something that the `Token` can refer info, + // and the easiest way to accomplish that is to + // represent the text using a `Name`. + // + Name* name = _astBuilder->getNamePool()->getName(content); value.setName(name); } } +} - void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } - - void decodeValue(SemanticVersion& value, Decoder& decoder) - { - SemanticVersion::RawValue rawValue = decoder.decode<SemanticVersion::RawValue>(); - value.setRawValue(rawValue); - } - - void decodeValue(CapabilitySet& value, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityTargetSet targetSet; - decode(targetSet, decoder); - value.getCapabilityTargetSets()[targetSet.target] = targetSet; - } - } - - void decodeValue(CapabilityTargetSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.target, decoder); +void ASTEncodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityStageSet stageSet; - decode(stageSet, decoder); - value.shaderStageSets[stageSet.stage] = stageSet; - } - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - void decodeValue(CapabilityStageSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.stage, decoder); - decode(value.atomSet, decoder); - } + TokenFlags flags = TokenFlags(value.flags & ~TokenFlag::Name); + serialize(serializer, flags); - void decodeValue(CapabilityAtomSet& value, Decoder& decoder) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityAtom atom; - decode(atom, decoder); - value.add(UInt(atom)); - } - } - - template<typename T> - void decodeValue(std::optional<T>& outValue, Decoder& decoder) - { - if (decoder.decodeNull()) - { - outValue.reset(); - } - else + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (value.hasContent()) { - T value; - decode(value, decoder); - outValue = value; + String content = value.getContent(); + serialize(serializer, content); } } +} - void decodeValue(SyntaxClass<NodeBase>& syntaxClass, Decoder& decoder) - { - ASTNodeType nodeType; - decode(nodeType, decoder); - syntaxClass = SyntaxClass<NodeBase>(nodeType); - } - - template<typename T> - void decodeValue(DeclRef<T>& declRef, Decoder& decoder) - { - decode(declRef.declRefBase, decoder); - } +// +// AST{Encoding|Decoding}Context::handleASTNode() +// - void decodeValue(ValNodeOperand& value, Decoder& decoder) +void ASTEncodingContext::handleASTNode(NodeBase*& node) +{ + if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - decodeEnum(value.kind, decoder); - switch (value.kind) + if (auto importedFromModule = _findModuleDeclWasImportedFrom(decl)) { - case ValNodeOperandKind::ConstantValue: - decode(value.values.intOperand, decoder); - break; - - case ValNodeOperandKind::ValNode: + if (decl == importedFromModule) { - Val* val = nullptr; - decode(val, decoder); - value.values.nodeOperand = val; + _writeImportedModule(importedFromModule); + return; } - break; - - case ValNodeOperandKind::ASTNode: + else { - Decl* decl = nullptr; - decode(decl, decoder); - value.values.nodeOperand = decl; + _writeImportedDecl(decl, importedFromModule); + return; } - break; } } - void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + ASTSerializer serializer(this); - void decodeValue(QualType& value, Decoder& decoder) + if (auto val = as<Val>(node)) { - Decoder::WithObject withObject(decoder); - decode(value.type, decoder); - decode(value.isLeftValue, decoder); - decode(value.hasReadOnlyOnTarget, decoder); - decode(value.isWriteOnly, decoder); - } + val = val->resolve(); - void decodeValue(MatrixCoord& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.row, decoder); - decode(value.col, decoder); + // On the reading side of things, sublcasses of `Val` + // are deduplicated as part of creation, and will read the + // operands out immediately, so we mirror that approach + // on the writing side to make sure the code is consistent. + // + serialize(serializer, val->astNodeType); + serialize(serializer, val->m_operands); } - - void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); + serialize(serializer, node->astNodeType); + deferSerializeObjectContents(serializer, node); } +} - void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.flavor, decoder); - decode(value.token, decoder); - decode(value.expr, decoder); - decode(value.bitwiseOrWith, decoder); - decode(value.knownValue, decoder); - decode(value.wrapInId, decoder); - decode(value.type, decoder); - } +void ASTDecodingContext::handleASTNode(NodeBase*& outNode) +{ + ASTSerializer serializer(this); - void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + ASTNodeType typeTag; + serialize(serializer, typeTag); + switch (_getPseudoASTNodeType(typeTag)) { - Decoder::WithObject withObject(decoder); - decode(value.opcode, decoder); - decode(value.operands, decoder); - } + default: + break; + case PseudoASTNodeType::ImportedModule: + outNode = _readImportedModule(); + return; - template<typename T> - void decodeEnum(T& value, Decoder& decoder) - { - value = T(decoder.decode<Int32>()); + case PseudoASTNodeType::ImportedDecl: + outNode = _readImportedDecl(); + return; } - template<typename T> - void decodeSimpleValue(T& value, Decoder& decoder) + auto syntaxClass = SyntaxClass<NodeBase>(typeTag); + if (syntaxClass.isSubClassOf<Val>()) { - value = decoder.decode<T>(); - } + // Subclasses of `Val` are deduplicated as part + // of creation, so we need to read in their + // operands before we can create them, rather + // than allocating the object up front and + // then deserializing its content into it later. - void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } - void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + ValNodeDesc desc; + desc.type = syntaxClass; + serialize(serializer, desc.operands); - void decodeValue(uint8_t& value, Decoder& decoder) - { - value = uint8_t(decoder.decode<UInt32>()); - } + desc.init(); - void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) - { - decodeEnum(value, decoder); + auto node = _astBuilder->_getOrCreateImpl(std::move(desc)); + outNode = node; } - void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); - } - void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + auto node = syntaxClass.createInstance(_astBuilder); + outNode = node; - - void decodeValue(SourceLoc& value, Decoder& decoder) - { - if (!decoder.decodeNull()) - { - SerialSourceLocData::SourceLoc intermediate; - decoder.decode(intermediate); - - if (_sourceLocReader) - { - auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); - value = sourceLoc; - } - } + deferSerializeObjectContents(serializer, node); } +} - template<typename T> - void decodeValue(T*& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - decodePtr(ptr, decoder, (T*)nullptr); - } +// +// AST{Encoding|Decoding}Context::handleASTNodeContents() +// - template<typename T> - void decodeValue(RefPtr<T>& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - { - // Hi Future Tess, - // - // The next step here is decoding logic for `WitnessTable`s. - // +void ASTEncodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); +} - decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); - } - } +void ASTDecodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); - void decodeValue(Modifiers& modifiers, Decoder& decoder) - { - Modifier** link = &modifiers.first; + _cleanUpASTNode(node); +} - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Modifier* modifier = nullptr; - decode(modifier, decoder); +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - *link = modifier; - link = &modifier->next; - } - } +void ASTEncodingContext::_writeImportedModule(ModuleDecl* moduleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedModule); + auto moduleName = moduleDecl->getName(); - template<typename T, int N> - void decodeValue(ShortList<T, N>& array, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } - } + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, moduleName); +} +ModuleDecl* ASTDecodingContext::_readImportedModule() +{ + ASTSerializer serializer(this); - template<typename T> - void decode(List<T>& array, Decoder& decoder) + Name* moduleName = nullptr; + serialize(serializer, moduleName); + auto module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } + SLANG_ABORT_COMPILATION("failed to load an imported module during AST deserialization"); } + return module->getModuleDecl(); +} - template<typename T, size_t N> - void decode(T (&array)[N], Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - for (auto& element : array) - { - decode(element, decoder); - } - } +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - template<typename K, typename V> - void decode(OrderedDictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); +void ASTEncodingContext::_writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedDecl); + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - K key; - V value; - decode(key, decoder); - decode(value, decoder); + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); +} - dictionary.add(key, value); - } - } +NodeBase* ASTDecodingContext::_readImportedDecl() +{ + ASTSerializer serializer(this); - template<typename K, typename V> - void decode(Dictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); + ModuleDecl* importedFromModuleDecl = nullptr; + String mangledName; - K key; - V value; - decode(key, decoder); - decode(value, decoder); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); - dictionary.add(key, value); - } + auto importedFromModule = importedFromModuleDecl->module; + if (!importedFromModule) + { + return nullptr; } - template<typename T> - void decode(T& outValue, Decoder& decoder) + auto importedDecl = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + if (!importedDecl) { - decodeValue(outValue, decoder); + SLANG_ABORT_COMPILATION( + "failed to load an imported declaration during AST deserialization"); } + return importedDecl; +} -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _decodeDataOf($T* obj, Decoder& decoder) - { -% if T.directSuperClass then - _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); -% end -% for _,f in ipairs(T.directFields) do - decode(obj->$f, decoder); -% end - } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 1 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END -}; +// +// {write|read}SerializedModuleAST() +// + +void writeSerializedModuleAST( + RIFF::BuildCursor& cursor, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) +{ + // TODO: we might want to have a more careful pass here, + // where we only encode the public declarations. + + ASTEncodingContext context(cursor, moduleDecl, sourceLocWriter); + serialize(ASTSerializer(&context), moduleDecl); +} ModuleDecl* readSerializedModuleAST( Linkage* linkage, @@ -1546,9 +919,10 @@ ModuleDecl* readSerializedModuleAST( { ASTDecodingContext context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); - context.decodeAll(); - auto node = context.getDeclByID(0); - auto moduleDecl = as<ModuleDecl>(node); + + ModuleDecl* moduleDecl = nullptr; + serialize(ASTSerializer(&context), moduleDecl); return moduleDecl; } + } // namespace Slang |
