diff options
Diffstat (limited to 'source/slang/slang-serialize-ast.cpp')
| -rw-r--r-- | source/slang/slang-serialize-ast.cpp | 1932 |
1 files changed, 653 insertions, 1279 deletions
diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 1ef532ad1..9a61dbc5a 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -5,6 +5,7 @@ #include "slang-compiler.h" #include "slang-diagnostics.h" #include "slang-mangle.h" +#include "slang-serialize-riff.h" namespace Slang { @@ -13,654 +14,512 @@ namespace Slang // NodeBase* parseSimpleSyntax(Parser* parser, void* userData); +// +// Many of the types used in the AST can be serialized using +// just the `Serializer` type, so we will handle all of those first. +// -struct ASTEncodingContext +void serialize(Serializer const& serializer, ASTNodeType& value) { -private: - Encoder* encoder; - struct UnhandledCase - { - }; - - typedef Int DeclID; - Dictionary<Decl*, DeclID> mapDeclToID; - List<Decl*> decls; - - struct ImportedDeclInfo - { - Int moduleIndex = -1; - Decl* decl; - }; - List<ImportedDeclInfo> importedDecls; - - typedef Int ValID; - Dictionary<Val*, ValID> mapValToID; - List<Val*> vals; - - ModuleDecl* _module = nullptr; - - SerialSourceLocWriter* _sourceLocWriter = nullptr; + serializeEnum(serializer, value); +} -public: - ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) - : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) - { - } +void serialize(Serializer const& serializer, TypeTag& value) +{ + serializeEnum(serializer, value); +} - template<typename T> - void encodeASTNodeContent(T* node) - { - Encoder::WithObject withObject(encoder); +void serialize(Serializer const& serializer, BaseType& value) +{ + serializeEnum(serializer, value); +} - ASTNodeDispatcher<T, void>::dispatch(node, [&](auto n) { _encodeDataOf(n); }); - } +void serialize(Serializer const& serializer, TryClauseType& value) +{ + serializeEnum(serializer, value); +} - void flush() - { - auto containerChunk = encoder->getRIFFChunk(); +void serialize(Serializer const& serializer, DeclVisibility& value) +{ + serializeEnum(serializer, value); +} - RIFF::ChunkBuilder* declChunk = nullptr; - RIFF::ChunkBuilder* importedDeclChunk = nullptr; - RIFF::ChunkBuilder* valChunk = nullptr; - { - Encoder::WithArray withList(encoder); - declChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - importedDeclChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - valChunk = encoder->getRIFFChunk(); - } - Int declIndex = 0; - Int importedDeclIndex = 0; - Int valIndex = 0; +void serialize(Serializer const& serializer, BuiltinRequirementKind& value) +{ + serializeEnum(serializer, value); +} - bool done = false; - do - { - done = true; - while (declIndex < decls.getCount()) - { - done = false; - encoder->setRIFFChunk(declChunk); - encodeASTNodeContent(decls[declIndex++]); - } - while (importedDeclIndex < importedDecls.getCount()) - { - done = false; - encoder->setRIFFChunk(importedDeclChunk); - encodeImportedDecl(importedDecls[importedDeclIndex++]); - } - while (valIndex < vals.getCount()) - { - done = false; - encoder->setRIFFChunk(valChunk); - encodeASTNodeContent(vals[valIndex++]); - } - } while (!done); +void serialize(Serializer const& serializer, ImageFormat& value) +{ + serializeEnum(serializer, value); +} - encoder->setRIFFChunk(containerChunk); - } +void serialize(Serializer const& serializer, PreferRecomputeAttribute::SideEffectBehavior& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleForDecl(Decl* decl) - { - for (auto d = decl; d; d = d->parentDecl) - { - if (auto m = as<ModuleDecl>(d)) - return m; - } - return nullptr; - } +void serialize(Serializer const& serializer, TreatAsDifferentiableExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) - { - auto declModule = findModuleForDecl(decl); - if (declModule == nullptr) - return nullptr; - if (declModule == _module) - return nullptr; - return declModule; - } +void serialize(Serializer const& serializer, LogicOperatorShortCircuitExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - DeclID getDeclID(Decl* decl) - { - SLANG_ASSERT(decl != nullptr); +void serialize(Serializer const& serializer, RequirementWitness::Flavor& value) +{ + serializeEnum(serializer, value); +} - if (auto found = mapDeclToID.tryGetValue(decl)) - return *found; +void serialize(Serializer const& serializer, CapabilityAtom& value) +{ + serializeEnum(serializer, value); +} - // We need to detect whether the declaration is an - // imported one, or one from this module itself. - // - // Imported declarations need to be handled very - // differently, since they'll involve resolving - // references to those other modules, and the - // declarations within them. - // - if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) - { - DeclID importedFromModuleDeclID = 0; - if (decl != importedFromModule) - { - importedFromModuleDeclID = getDeclID(importedFromModule); - } +void serialize(Serializer const& serializer, DeclAssociationKind& value) +{ + serializeEnum(serializer, value); +} - DeclID id = ~importedDecls.getCount(); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, TokenType& value) +{ + serializeEnum(serializer, value); +} - ImportedDeclInfo info; - info.moduleIndex = ~importedFromModuleDeclID; - info.decl = decl; - importedDecls.add(info); +void serialize(Serializer const& serializer, ValNodeOperandKind& value) +{ + serializeEnum(serializer, value); +} - return id; - } - else - { - DeclID id = decls.getCount(); - decls.add(decl); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, SPIRVAsmOperand::Flavor& value) +{ + serializeEnum(serializer, value); +} - return id; - } - } +void serialize(Serializer const& serializer, MatrixCoord& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.row); + serialize(serializer, value.col); +} - void encodePtr(Decl* decl) +void serializePtr(Serializer const& serializer, DiagnosticInfo const*& value, DiagnosticInfo const*) +{ + Int32 id = 0; + if (isWriting(serializer)) { - DeclID id = getDeclID(decl); - encoder->encode(id); + id = value->id; + serialize(serializer, id); } - - ValID getValID(Val* val) + else { - SLANG_ASSERT(val != nullptr); - - if (auto found = mapValToID.tryGetValue(val)) - return *found; - - // In order to ensure that values can be fully constructed - // from the get-go (so that they will get cached correctly), - // we conspire to ensure that every value is preceded by - // all of its operands. - // - for (auto operand : val->m_operands) - { - switch (operand.kind) - { - default: - break; - - case ValNodeOperandKind::ValNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Val>(operandNode)); - getValID(static_cast<Val*>(operandNode)); - } - break; - - case ValNodeOperandKind::ASTNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Decl>(operandNode)); - getDeclID(static_cast<Decl*>(operandNode)); - } - break; - } - } - auto resolved = val->resolve(); - if (resolved != val) - { - getValID(resolved); - } - - ValID id = vals.getCount(); - vals.add(val); - mapValToID.add(val, id); - return id; + serialize(serializer, id); + value = getDiagnosticsLookup()->getDiagnosticById(id); } +} - void encodePtr(Val* val) +void serialize(Serializer const& serializer, SemanticVersion& value) +{ + auto raw = value.getRawValue(); + serialize(serializer, raw); + value = SemanticVersion::fromRaw(raw); +} + +void serialize(Serializer const& serializer, SyntaxClass<NodeBase>& value) +{ + ASTNodeType raw; + if (isWriting(serializer)) { - ValID id = getValID(val); - encoder->encode(id); + raw = value.getTag(); } - - void encodeImportedDecl(ImportedDeclInfo const& info) + serialize(serializer, raw); + if (isReading(serializer)) { - Encoder::WithKeyValuePair withPair(encoder); - encode(info.moduleIndex); - auto decl = info.decl; - if (auto importedModuleDecl = as<ModuleDecl>(decl)) - { - SLANG_ASSERT(info.moduleIndex == -1); - encode(importedModuleDecl->getName()); - } - else - { - auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - encode(mangledName); - } + value = SyntaxClass<NodeBase>(raw); } +} - void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } - void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } - void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } +// +// Many types in the AST need additional context (beyond +// what the `Serializer` has) in order to serialize +// themselves or their members. +// +// We define a custom serializer interface to capture +// the cases that can't be handled by a `Serializer` +// alone. +// - void encodePtr(Name* name) { encode(name->text); } +/// Interface for AST serialization +struct ASTSerializerImpl +{ +public: + virtual void handleASTNode(NodeBase*& value) = 0; + virtual void handleASTNodeContents(NodeBase* value) = 0; + virtual void handleName(Name*& value) = 0; + virtual void handleSourceLoc(SourceLoc& value) = 0; + virtual void handleToken(Token& value) = 0; - void encodePtr(MarkupEntry* entry) - { - // TODO: is this case needed? - SLANG_UNUSED(entry); - } + // Note that this type does *not* inherit from `ISerializerImpl`. + // + // We want to decouple the AST-specific context information + // from the lower-level details of the serialization format. + // + // Instead of using inheritance, we expect that any + // `ASTSerializerImpl` will aggregate a lower-level + // serializer, and the interface exposes access to + // that base serializer implementation. - void encodePtr(DeclAssociationList* list) - { - // We serialize this as if it were a simple list - // of key-value pairs because... well... that's - // what it amounts to in practice. - // - Encoder::WithArray withArray(encoder); - for (auto association : list->associations) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(association->kind); - encode(association->decl); - } - } + virtual ISerializerImpl* getBaseSerializer() = 0; +}; - void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } +/// Specialization of `Serializer_` for AST serialization. +template<> +struct Serializer_<ASTSerializerImpl> : SerializerBase<ASTSerializerImpl> +{ +public: + using SerializerBase::SerializerBase; - void encodePtr(WitnessTable* witnessTable) - { - Encoder::WithObject withObject(encoder); - encode(witnessTable->baseType); - encode(witnessTable->witnessedType); - encode(witnessTable->isExtern); - - // TODO(tfoley): In theory we should be able to streamline - // this so that we only encode the requirements that we - // absolutely need to (which basically amounts to `associatedtype` - // requirements where the satisfying type is part of the public - // API of the type). - // - encode(witnessTable->m_requirementDictionary); - } + // + // In order to allow an `ASTSerializer` to be used with + // functions that expect an ordinary `Serializer`, we + // implement an implicit conversion operator. + // - void encodeValue(RequirementWitness const& witness) - { - Encoder::WithKeyValuePair withPair(encoder); - encodeEnum(witness.m_flavor); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; + operator Serializer() const { return Serializer(get()->getBaseSerializer()); } +}; - case RequirementWitness::Flavor::declRef: - encode(witness.m_declRef); - break; +/// Context type for AST serialization. +using ASTSerializer = Serializer_<ASTSerializerImpl>; - case RequirementWitness::Flavor::val: - encode(witness.m_val); - break; +template<typename T> +void serializeObject(ASTSerializer const& serializer, T*& value, NodeBase*) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serializer->handleASTNode(*(NodeBase**)&value); +} - case RequirementWitness::Flavor::witnessTable: - encode((WitnessTable*)witness.m_obj.Ptr()); - break; - } - } +void serializeObjectContents(ASTSerializer const& serializer, NodeBase* value, NodeBase*) +{ + serializer->handleASTNodeContents(value); +} - void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } +template<typename T> +void serialize(ASTSerializer const& serializer, DeclRef<T>& value) +{ + serialize(serializer, value.declRefBase); +} - void encodePtr(DeclBase* declBase) +void serialize(ASTSerializer const& serializer, SourceLoc& value) +{ + serializer->handleSourceLoc(value); +} + +void serialize(ASTSerializer const& serializer, RequirementWitness& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.m_flavor); + switch (value.m_flavor) { - if (auto decl = as<Decl>(declBase)) - { - encodePtr(decl); - } - else - { - encodeASTNodeContent(declBase); - } - } + case RequirementWitness::Flavor::none: + break; - void encodeValue(UnhandledCase); + case RequirementWitness::Flavor::declRef: + serialize(serializer, value.m_declRef); + break; - void encodeValue(String const& value) { encoder->encode(value); } + case RequirementWitness::Flavor::val: + serialize(serializer, value.m_val); + break; - void encodeValue(Token const& value) - { - encode(value.type); - encode(TokenFlags(value.flags & ~TokenFlag::Name)); - encode(value.loc); - if (value.hasContent()) - encoder->encodeString(value.getContent()); - else - encode(nullptr); + case RequirementWitness::Flavor::witnessTable: + serialize(serializer, value.m_obj); + break; } +} - void encodeValue(NameLoc const& value) { encode(value.name); } - - void encodeValue(SemanticVersion value) { encoder->encode(value.getRawValue()); } +void serialize(ASTSerializer const& serializer, WitnessTable& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.baseType); + serialize(serializer, value.witnessedType); + serialize(serializer, value.isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + serialize(serializer, value.m_requirementDictionary); +} - void encodeValue(CapabilitySet const& value) +void serialize(Serializer const& serializer, CapabilityAtomSet& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - // While the `CapabilityTargetSets` type is a dictionary, - // in practice each entry already embeds its own key - // (the target atom), so we can encode this as just - // an array of the `CapabilityTargetSet` values. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.getCapabilityTargetSets()) + for (auto rawAtom : value) { - encode(pair.second); + auto atom = CapabilityAtom(rawAtom); + serialize(serializer, atom); } } - - void encodeValue(CapabilityTargetSet const& value) + else { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.target); - - // Similar to the case for the `CapabilityTargetSets` above, - // each `CapabilityStageSet` already includes the stage atom, - // so we can simply encode the values from the dictionary. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.shaderStageSets) + while (hasElements(serializer)) { - encode(pair.second); + CapabilityAtom atom; + serialize(serializer, atom); + value.add(UInt(atom)); } } +} - void encodeValue(CapabilityStageSet const& value) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.stage); - encode(value.atomSet); - } +void serialize(Serializer const& serializer, CapabilityStageSet& value) +{ + serialize(serializer, value.atomSet); +} - void encodeValue(CapabilityAtomSet const& value) +void serialize(Serializer const& serializer, CapabilityTargetSet& value) +{ + serialize(serializer, value.shaderStageSets); + + // The value for each entry in `shaderStageSets` have + // a `stage` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityStageSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - Encoder::WithArray withArray(encoder); - for (auto rawAtom : value) - { - encode(CapabilityAtom(rawAtom)); - } + for (auto& p : value.shaderStageSets) + p.second.stage = p.first; } +} - template<typename T> - void encodeValue(std::optional<T> const& value) +void serialize(Serializer const& serializer, CapabilitySet& value) +{ + serialize(serializer, value.getCapabilityTargetSets()); + + // The value for each entry in `getCapabilityTargetSets()` have + // a `target` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityTargetSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - if (value) - encodeValue(*value); - else - encoder->encode(nullptr); + for (auto& p : value.getCapabilityTargetSets()) + p.second.target = p.first; } +} + +void serialize(ASTSerializer const& serializer, CandidateExtensionList& value) +{ + serialize(serializer, value.candidateExtensions); +} + +void serialize(ASTSerializer const& serializer, DeclAssociation& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.kind); + serialize(serializer, value.decl); +} - void encodeValue(SyntaxClass<NodeBase> const& value) { encode(value.getTag()); } +void serialize(ASTSerializer const& serializer, DeclAssociationList& value) +{ + serialize(serializer, value.associations); +} - template<typename T> - void encodeValue(DeclRef<T> const& value) +void serialize(ASTSerializer const& serializer, Modifiers& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - encode((DeclRefBase*)value); + for (auto modifier : value) + { + serialize(serializer, modifier); + } } - - void encodeValue(ValNodeOperand value) + else { - Encoder::WithKeyValuePair withPair(encoder); + Modifier** link = &value.first; - encodeEnum(value.kind); - switch (value.kind) + while (hasElements(serializer)) { - case ValNodeOperandKind::ConstantValue: - encode(value.values.intOperand); - break; - - case ValNodeOperandKind::ValNode: - encode(static_cast<Val*>(value.values.nodeOperand)); - break; + Modifier* modifier = nullptr; + serialize(serializer, modifier); - case ValNodeOperandKind::ASTNode: - { - if (auto decl = as<Decl>(value.values.nodeOperand)) - { - encode(decl); - } - else - { - SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); - } - } - break; + *link = modifier; + link = &modifier->next; } } +} - void encodeValue(TypeExp value) { encode(value.type); } +void serialize(ASTSerializer const& serializer, TypeExp& value) +{ + serialize(serializer, value.type); +} - void encodeValue(QualType value) - { - Encoder::WithObject withObject(encoder); - encode(value.type); - encode(value.isLeftValue); - encode(value.hasReadOnlyOnTarget); - encode(value.isWriteOnly); - } +void serialize(ASTSerializer const& serializer, QualType& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.isLeftValue); + serialize(serializer, value.hasReadOnlyOnTarget); + serialize(serializer, value.isWriteOnly); +} - void encodeValue(MatrixCoord value) - { - Encoder::WithObject withObject(encoder); - encode(value.row); - encode(value.col); - } +void serialize(ASTSerializer const& serializer, Token& value) +{ + serializer->handleToken(value); +} - void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } +void serialize(ASTSerializer const& serializer, SPIRVAsmOperand& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.flavor); + serialize(serializer, value.token); + serialize(serializer, value.expr); + serialize(serializer, value.bitwiseOrWith); + serialize(serializer, value.knownValue); + serialize(serializer, value.wrapInId); + serialize(serializer, value.type); +} - void encodeValue(SPIRVAsmOperand const& value) - { - Encoder::WithObject withObject(encoder); - encode(value.flavor); - encode(value.token); - encode(value.expr); - encode(value.bitwiseOrWith); - encode(value.knownValue); - encode(value.wrapInId); - encode(value.type); - } +void serialize(ASTSerializer const& serializer, SPIRVAsmInst& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.opcode); + serialize(serializer, value.operands); +} - void encodeValue(SPIRVAsmInst const& value) +void serialize(ASTSerializer const& serializer, ValNodeOperand& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.kind); + switch (value.kind) { - Encoder::WithObject withObject(encoder); - encode(value.opcode); - encode(value.operands); - } + case ValNodeOperandKind::ConstantValue: + serialize(serializer, value.values.intOperand); + break; - - template<typename T, typename = std::enable_if_t<std::is_same_v<T, bool>>> - void encodeValue(T value) - { - encoder->encodeBool(value); + case ValNodeOperandKind::ValNode: + case ValNodeOperandKind::ASTNode: + serialize(serializer, value.values.nodeOperand); + break; } +} - void encodeValue(Int32 value) { encoder->encode(value); } - void encodeValue(UInt32 value) { encoder->encode(value); } - void encodeValue(Int64 value) { encoder->encode(value); } - void encodeValue(UInt64 value) { encoder->encode(value); } - void encodeValue(float value) { encoder->encode(value); } - void encodeValue(double value) { encoder->encode(value); } - - void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } - - void encodeValue(std::nullptr_t) { encoder->encode(nullptr); } +void serializeObject(ASTSerializer const& serializer, Name*& value, Name*) +{ + serializer->handleName(value); +} - template<typename T> - void encodeEnum(T value) - { - encoder->encode(Int32(value)); - } +void serialize(ASTSerializer const& serializer, NameLoc& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.name); + serialize(serializer, value.loc); +} - void encodeValue(DeclVisibility value) { encodeEnum(value); } - void encodeValue(BaseType value) { encodeEnum(value); } - void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } - void encodeValue(ASTNodeType value) { encodeEnum(value); } - void encodeValue(ImageFormat value) { encodeEnum(value); } - void encodeValue(TypeTag value) { encodeEnum(value); } - void encodeValue(TryClauseType value) { encodeEnum(value); } - void encodeValue(CapabilityAtom value) { encodeEnum(value); } - void encodeValue(DeclAssociationKind value) { encodeEnum(value); } - void encodeValue(TokenType value) { encodeEnum(value); } - - void encodeValue(SourceLoc value) - { - if (!_sourceLocWriter) - { - encoder->encode(nullptr); - } - else - { - auto intermediate = _sourceLocWriter->addSourceLoc(value); - encoder->encode(intermediate); +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do +void _serializeASTNodeContents(ASTSerializer const& serializer, $T* value) +{ + SLANG_UNUSED(serializer); + SLANG_UNUSED(value); +% if T.directSuperClass then + _serializeASTNodeContents(serializer, static_cast<$(T.directSuperClass)*>(value)); +% end +% for _,f in ipairs(T.directFields) do + serialize(serializer, value->$f); +% end } - } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END - template<typename T> - void encodeValue(T const* ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(const_cast<T*>(ptr)); - } - } +void serializeASTNodeContents(ASTSerializer const& serializer, NodeBase* node) +{ + ASTNodeDispatcher<NodeBase, void>::dispatch( + node, + [&](auto n) { _serializeASTNodeContents(serializer, n); }); +} - template<typename T> - void encodeValue(RefPtr<T> const& ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(ptr.Ptr()); - } - } +enum class PseudoASTNodeType +{ + None, + ImportedModule, + ImportedDecl, +}; - void encodeValue(Modifiers const& modifiers) - { - Encoder::WithArray withArray(encoder); - for (auto m : const_cast<Modifiers&>(modifiers)) - { - encode(m); - } - } +static PseudoASTNodeType _getPseudoASTNodeType(ASTNodeType type) +{ + return int(type) < 0 ? PseudoASTNodeType(~int(type)) : PseudoASTNodeType::None; +} + +static ASTNodeType _getAsASTNodeType(PseudoASTNodeType type) +{ + return ASTNodeType(~int(type)); +} - template<typename T, int N> - void encodeValue(ShortList<T, N> const& array) +struct ASTEncodingContext : ASTSerializerImpl +{ +public: + ASTEncodingContext( + RIFF::BuildCursor& cursor, + ModuleDecl* module, + SerialSourceLocWriter* sourceLocWriter) + : _writer(cursor.getCurrentChunk()), _module(module), _sourceLocWriter(sourceLocWriter) { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } } +private: + RIFFSerialWriter _writer; + ModuleDecl* _module = nullptr; + SerialSourceLocWriter* _sourceLocWriter = nullptr; - template<typename T> - void encode(List<T> const& array) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual ISerializerImpl* getBaseSerializer() override { return &_writer; } - template<typename T, size_t N> - void encode(T const (&array)[N]) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& node) override; + virtual void handleASTNodeContents(NodeBase* node) override; - template<typename K, typename V> - void encode(OrderedDictionary<K, V> const& dictionary) - { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.key); - encode(p.value); - } - } + void _writeImportedModule(ModuleDecl* moduleDecl); + void _writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl); - template<typename K, typename V> - void encode(Dictionary<K, V> const& dictionary) + ModuleDecl* _findModuleForDecl(Decl* decl) { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) + for (auto d = decl; d; d = d->parentDecl) { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.first); - encode(p.second); + if (auto m = as<ModuleDecl>(d)) + return m; } + return nullptr; } - template<typename T> - void encode(T const& value) + ModuleDecl* _findModuleDeclWasImportedFrom(Decl* decl) { - encodeValue(value); - } - - // for each class of node, we generate - // code to recursively serialize each - // of its fields. - -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _encodeDataOf($T* obj) - { -%if T.directSuperClass then - _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); -%end -%for _,f in ipairs(T.directFields) do - encode(obj->$f); -%end + auto declModule = _findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 0 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END }; -void writeSerializedModuleAST( - Encoder* encoder, - ModuleDecl* moduleDecl, - SerialSourceLocWriter* sourceLocWriter) -{ - Encoder::WithObject withObject(encoder); - - // TODO: we should have a more careful pass here, - // where we only encode the public declarations - // - - ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); - context.getDeclID(moduleDecl); - context.flush(); -} - -struct ASTDecodingContext +struct ASTDecodingContext : ASTSerializerImpl { public: ASTDecodingContext( @@ -673,284 +532,60 @@ public: : _linkage(linkage) , _astBuilder(astBuilder) , _sink(sink) - , _baseChunk(as<RIFF::ListChunk>(baseChunk)) , _sourceLocReader(sourceLocReader) , _requestingSourceLoc(requestingSourceLoc) + , _riffReader(baseChunk) { } +private: Linkage* _linkage = nullptr; + ASTBuilder* _astBuilder = nullptr; DiagnosticSink* _sink = nullptr; SerialSourceLocReader* _sourceLocReader = nullptr; SourceLoc _requestingSourceLoc; + RIFFSerialReader _riffReader; - SlangResult decodeAll() - { - auto cursor = _baseChunk->getChildren().begin(); - - // There are a few different top-level chunks that - // hold different arrays that we need in order - // to decode the entire module hierarchy. - // - // Basically, these lists correspond to the kinds - // of nodes in the AST hierarchy for which back-references - // are allowed (all other nodes should, barring - // weird corner cases, form a single tree-structured - // ownership hierarchy, rooted at the `ModuleDecl`. - // - - // First there is the list that actually encodes - // for the declarations in the module, including - // the `ModuleDecl` itself, which should be the - // first entry in the list. - // - auto declChunk = *cursor; - ++cursor; - - // Next there is a list of all the declarations - // referenced inside of the module that need to - // be imported in from outside. - // - auto importedDeclChunk = *cursor; - ++cursor; + virtual ISerializerImpl* getBaseSerializer() override { return &_riffReader; } - // Then there are all the `Val`-derived nodes that - // are needed by the module, which will need to be - // deduplicated so that they are unique within the - // current compilation context. - // - auto valChunk = *cursor; - ++cursor; + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& outNode) override; + virtual void handleASTNodeContents(NodeBase* node) override; - // The process of decoding the module is then spread - // over a number of steps. - // - // The first step is to process all of the imported - // declarations, so that other nodes can refer to - // them. - // - SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); - - // Next we process the declarations that are within - // the module itself, first creating an "empty shell" - // of each declaration that has the right size in - // memory (and the right `ASTNodeType` tag), so that - // we can wire up references to it (including circular - // references)... so long as nothing here tries to - // look *inside* the empty shell along the way. - // - SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); - - // Once all the `Decl`s that might be needed have - // been allocated, we can process all the `Val`s - // that might reference those`Decl`s (and one another). - // - // The nature of the `Val` representation ensures - // that there cannot be cirularities in the references - // between `Val`s, and the encoding process will have - // sorted the entries so that a `Val` only ever appears - // *after* its operands. - // - SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + ModuleDecl* _readImportedModule(); + NodeBase* _readImportedDecl(); - // Once all the back-reference-able objects have been - // instantiated in memory, we can go back through the - // `Decl`s in the module and fill in those empty shells. - // - SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); - - // As a final pass, we perform any special cleanup actions - // that might be required to make the output valid for consumers. - // - // For example, this is where we set the `DeclCheckState` of everything - // we are loading to reflect the fact that everything we deserialize - // is (supposed to be) fully cheked. - // - SLANG_RETURN_ON_FAIL(cleanUpNodes()); - - - return SLANG_OK; - } - - typedef Int DeclID; - Decl* getDeclByID(DeclID id) + void _cleanUpASTNode(NodeBase* node) { - if (id >= 0) - { - return _decls[id]; - } - else + if (auto expr = as<Expr>(node)) { - return _importedDecls[~id]; + expr->checked = true; } - } - -private: - struct UnhandledCase - { - }; - - ASTBuilder* _astBuilder = nullptr; - RIFF::ListChunk const* _baseChunk = nullptr; - - List<Decl*> _decls; - List<Decl*> _importedDecls; - List<Val*> _vals; - - typedef Int ValID; - Val* getValByID(ValID id) { return _vals[id]; } - - SlangResult decodeImportedDecls(RIFF::Chunk const* importedDeclChunk) - { - Decoder decoder(importedDeclChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + else if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - Int moduleIndex; - decode(moduleIndex, decoder); + decl->checkState = DeclCheckState::CapabilityChecked; - if (moduleIndex == -1) + if (auto genericDecl = as<GenericDecl>(node)) { - Name* moduleName = nullptr; - decode(moduleName, decoder); - - Decl* importedModule = getImportedModule(moduleName); - _importedDecls.add(importedModule); + _assignGenericParameterIndices(genericDecl); } - else + else if (auto syntaxDecl = as<SyntaxDecl>(node)) { - auto importedFromModuleDecl = as<ModuleDecl>(_importedDecls[moduleIndex]); - auto importedFromModule = importedFromModuleDecl->module; - - String mangledName; - decode(mangledName, decoder); - - auto importedNode = - importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); - auto importedDecl = as<Decl>(importedNode); - _importedDecls.add(importedDecl); + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); } - } - return SLANG_OK; - } - - ModuleDecl* getImportedModule(Name* moduleName) - { - Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); - if (!module) - { - SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); - } - - return module->getModuleDecl(); - } - - SlangResult decodeVals(RIFF::Chunk const* valChunk) - { - Decoder decoder(valChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Val* val = decodeValNode(decoder); - _vals.add(val); - } - return SLANG_OK; - } - - SlangResult createEmptyShells(RIFF::Chunk const* declChunk) - { - Decoder decoder(declChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ASTNodeType nodeType; - - // Each of the declarations is expected to take - // the form of an object with a first field - // that holds the node type. - // + else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); + auto declScope = _astBuilder->create<Scope>(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; } - - auto emptyShell = createEmptyShell(nodeType); - auto declEmptyShell = as<Decl>(emptyShell); - _decls.add(declEmptyShell); } - - return SLANG_OK; } - Val* decodeValNode(Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeType nodeType; - decode(nodeType, decoder); - - ValNodeDesc desc; - desc.type = SyntaxClass<NodeBase>(nodeType); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ValNodeOperand operand; - decode(operand, decoder); - desc.operands.add(operand); - } - - desc.init(); - - auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - - // Values created during deserialization are - // not expected to ever resolve further, because - // they should be coming from fully checked code. - // - // val->resolve(); - // val->_setUnique(); - - return val; - } - - NodeBase* createEmptyShell(ASTNodeType nodeType) - { - return SyntaxClass<NodeBase>(nodeType).createInstance(_astBuilder); - } - - SlangResult fillEmptyShells(RIFF::Chunk const* declChunk) - { - Index declIndex = 0; - - Decoder decoder(declChunk); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - auto declEmptyShell = _decls[declIndex++]; - decodeASTNodeContent(declEmptyShell, decoder); - } - - return SLANG_OK; - } - - SlangResult cleanUpNodes() - { - for (auto decl : _decls) - { - decl->checkState = DeclCheckState::CapabilityChecked; - } - - return SLANG_OK; - } - - - void assignGenericParameterIndices(GenericDecl* genericDecl) + void _assignGenericParameterIndices(GenericDecl* genericDecl) { int parameterCounter = 0; for (auto m : genericDecl->members) @@ -965,576 +600,314 @@ private: } } } +}; +// +// We are matching up the corresponding `handle*()` operations from the +// `AST{Encoding|Decoding}Context` types here, so that it is easier +// to visually verify that they are serializing the same data with the +// same ordering. +// - void cleanUpASTNode(NodeBase* node) - { - if (auto expr = as<Expr>(node)) - { - expr->checked = true; - } - else if (auto genericDecl = as<GenericDecl>(node)) - { - assignGenericParameterIndices(genericDecl); - } - else if (auto syntaxDecl = as<SyntaxDecl>(node)) - { - syntaxDecl->parseCallback = &parseSimpleSyntax; - syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); - } - else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) - { - auto declScope = _astBuilder->create<Scope>(); - declScope->containerDecl = namespaceLikeDecl; - namespaceLikeDecl->ownedScope = declScope; - } - } - - void decodeASTNodeContent(NodeBase* node, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeDispatcher<NodeBase, void>::dispatch( - node, - [&](auto n) { _decodeDataOf(n, decoder); }); +// +// AST{Encoding|Decoding}Context::handleName() +// - cleanUpASTNode(node); - } +void ASTEncodingContext::handleName(Name*& value) +{ + serialize(ASTSerializer(this), value->text); +} - DeclID decodeDeclID(Decoder& decoder) - { - DeclID result = decoder.decode<DeclID>(); - return result; - } +void ASTDecodingContext::handleName(Name*& value) +{ + String text; + serialize(ASTSerializer(this), text); + value = _astBuilder->getNamePool()->getName(text); +} - ValID decodeValID(Decoder& decoder) - { - ValID result = decoder.decode<ValID>(); - return result; - } +// +// AST{Encoding|Decoding}Context::handleSourceLoc() +// - template<typename T> - void decodeASTNode(T*& node, Decoder& decoder) +void ASTEncodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (_sourceLocWriter != nullptr) { - ASTNodeType nodeType; - auto saved = decoder.getCursor(); - { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); - } - decoder.setCursor(saved); - - auto shell = createEmptyShell(nodeType); - decodeASTNodeContent(shell, decoder); - - node = as<T>(shell); + auto rawValue = _sourceLocWriter->addSourceLoc(value); + serialize(serializer, rawValue); } +} - void decodePtr(Name*& name, Decoder& decoder, Name*) +void ASTDecodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - String text; - decode(text, decoder); - - name = _astBuilder->getNamePool()->getName(text); - } + SerialSourceLocData::SourceLoc rawValue; + serialize(serializer, rawValue); - void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) - { - // Mirroring the encoding logic, we decode this - // as a list of key-value pairs. - // - auto list = RefPtr(new DeclAssociationList()); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + if (_sourceLocReader) { - auto association = RefPtr(new DeclAssociation()); - - Decoder::WithKeyValuePair withPair(decoder); - decode(association->kind, decoder); - decode(association->decl, decoder); - - list->associations.add(association); + value = _sourceLocReader->getSourceLoc(rawValue); } - - outList = list.detach(); } +} - void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) - { - Int id; - decode(id, decoder); - info = getDiagnosticsLookup()->getDiagnosticById(id); - } - - void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) - { - // TODO: is this case needed? - markupEntry = nullptr; - } - - void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) - { - auto result = RefPtr(new CandidateExtensionList()); - decode(result->candidateExtensions, decoder); - list = result.detach(); - } - - void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) - { - Decoder::WithObject withObject(decoder); - auto wt = RefPtr(new WitnessTable()); - decode(wt->baseType, decoder); - decode(wt->witnessedType, decoder); - decode(wt->isExtern, decoder); - decode(wt->m_requirementDictionary, decoder); - witnessTable = wt.detach(); - } - - void decodeValue(RequirementWitness& witness, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decodeEnum(witness.m_flavor, decoder); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; - - case RequirementWitness::Flavor::declRef: - decode(witness.m_declRef, decoder); - break; - - case RequirementWitness::Flavor::val: - decode(witness.m_val, decoder); - break; +// +// AST{Encoding|Decoding}Context::handleToken() +// - case RequirementWitness::Flavor::witnessTable: - { - RefPtr<WitnessTable> object; - decode(object, decoder); - witness.m_obj = object; - } - break; - } - } +void ASTDecodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Val*) - { - ValID id = decodeValID(decoder); - node = static_cast<T*>(getValByID(id)); - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Decl*) - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } + serialize(serializer, value.flags); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, DeclBase*) { - // This case is a bit of a hack. We need - // to identify whether we are looking at - // an indirection to a `Decl` (which would - // be serialized as an integer `DeclID`), - // or something else derived from `DeclBase`. - // - switch (decoder.getTag()) + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - default: - decodeASTNode(node, decoder); - break; - - case SerialBinary::kInt32FourCC: - case SerialBinary::kInt64FourCC: - case SerialBinary::kUInt32FourCC: - case SerialBinary::kUInt64FourCC: - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } - break; - } - } - - template<typename T> - void decodePtr(T*& node, Decoder& decoder, NodeBase*) - { - decodeASTNode(node, decoder); - } + String content; + serialize(serializer, content); - - void decodeValue(UnhandledCase, Decoder& decoder); - - void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } - - void decodeValue(Token& value, Decoder& decoder) - { - decode(value.type, decoder); - decode(value.flags, decoder); - decode(value.loc, decoder); - if (decoder.decodeNull()) - { - } - else - { - Name* name = nullptr; - decode(name, decoder); + // An important note here is that we cannot just + // call `value.setContent(...)` and pass in an + // `UnownedStringSlice` of `content`, because the + // `Token` will not take ownership of its own + // textual content. + // + // Instead, we need to get the text we just loaded + // into something that the `Token` can refer info, + // and the easiest way to accomplish that is to + // represent the text using a `Name`. + // + Name* name = _astBuilder->getNamePool()->getName(content); value.setName(name); } } +} - void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } - - void decodeValue(SemanticVersion& value, Decoder& decoder) - { - SemanticVersion::RawValue rawValue = decoder.decode<SemanticVersion::RawValue>(); - value.setRawValue(rawValue); - } - - void decodeValue(CapabilitySet& value, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityTargetSet targetSet; - decode(targetSet, decoder); - value.getCapabilityTargetSets()[targetSet.target] = targetSet; - } - } - - void decodeValue(CapabilityTargetSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.target, decoder); +void ASTEncodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityStageSet stageSet; - decode(stageSet, decoder); - value.shaderStageSets[stageSet.stage] = stageSet; - } - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - void decodeValue(CapabilityStageSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.stage, decoder); - decode(value.atomSet, decoder); - } + TokenFlags flags = TokenFlags(value.flags & ~TokenFlag::Name); + serialize(serializer, flags); - void decodeValue(CapabilityAtomSet& value, Decoder& decoder) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityAtom atom; - decode(atom, decoder); - value.add(UInt(atom)); - } - } - - template<typename T> - void decodeValue(std::optional<T>& outValue, Decoder& decoder) - { - if (decoder.decodeNull()) - { - outValue.reset(); - } - else + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (value.hasContent()) { - T value; - decode(value, decoder); - outValue = value; + String content = value.getContent(); + serialize(serializer, content); } } +} - void decodeValue(SyntaxClass<NodeBase>& syntaxClass, Decoder& decoder) - { - ASTNodeType nodeType; - decode(nodeType, decoder); - syntaxClass = SyntaxClass<NodeBase>(nodeType); - } - - template<typename T> - void decodeValue(DeclRef<T>& declRef, Decoder& decoder) - { - decode(declRef.declRefBase, decoder); - } +// +// AST{Encoding|Decoding}Context::handleASTNode() +// - void decodeValue(ValNodeOperand& value, Decoder& decoder) +void ASTEncodingContext::handleASTNode(NodeBase*& node) +{ + if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - decodeEnum(value.kind, decoder); - switch (value.kind) + if (auto importedFromModule = _findModuleDeclWasImportedFrom(decl)) { - case ValNodeOperandKind::ConstantValue: - decode(value.values.intOperand, decoder); - break; - - case ValNodeOperandKind::ValNode: + if (decl == importedFromModule) { - Val* val = nullptr; - decode(val, decoder); - value.values.nodeOperand = val; + _writeImportedModule(importedFromModule); + return; } - break; - - case ValNodeOperandKind::ASTNode: + else { - Decl* decl = nullptr; - decode(decl, decoder); - value.values.nodeOperand = decl; + _writeImportedDecl(decl, importedFromModule); + return; } - break; } } - void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + ASTSerializer serializer(this); - void decodeValue(QualType& value, Decoder& decoder) + if (auto val = as<Val>(node)) { - Decoder::WithObject withObject(decoder); - decode(value.type, decoder); - decode(value.isLeftValue, decoder); - decode(value.hasReadOnlyOnTarget, decoder); - decode(value.isWriteOnly, decoder); - } + val = val->resolve(); - void decodeValue(MatrixCoord& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.row, decoder); - decode(value.col, decoder); + // On the reading side of things, sublcasses of `Val` + // are deduplicated as part of creation, and will read the + // operands out immediately, so we mirror that approach + // on the writing side to make sure the code is consistent. + // + serialize(serializer, val->astNodeType); + serialize(serializer, val->m_operands); } - - void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); + serialize(serializer, node->astNodeType); + deferSerializeObjectContents(serializer, node); } +} - void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.flavor, decoder); - decode(value.token, decoder); - decode(value.expr, decoder); - decode(value.bitwiseOrWith, decoder); - decode(value.knownValue, decoder); - decode(value.wrapInId, decoder); - decode(value.type, decoder); - } +void ASTDecodingContext::handleASTNode(NodeBase*& outNode) +{ + ASTSerializer serializer(this); - void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + ASTNodeType typeTag; + serialize(serializer, typeTag); + switch (_getPseudoASTNodeType(typeTag)) { - Decoder::WithObject withObject(decoder); - decode(value.opcode, decoder); - decode(value.operands, decoder); - } + default: + break; + case PseudoASTNodeType::ImportedModule: + outNode = _readImportedModule(); + return; - template<typename T> - void decodeEnum(T& value, Decoder& decoder) - { - value = T(decoder.decode<Int32>()); + case PseudoASTNodeType::ImportedDecl: + outNode = _readImportedDecl(); + return; } - template<typename T> - void decodeSimpleValue(T& value, Decoder& decoder) + auto syntaxClass = SyntaxClass<NodeBase>(typeTag); + if (syntaxClass.isSubClassOf<Val>()) { - value = decoder.decode<T>(); - } + // Subclasses of `Val` are deduplicated as part + // of creation, so we need to read in their + // operands before we can create them, rather + // than allocating the object up front and + // then deserializing its content into it later. - void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } - void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + ValNodeDesc desc; + desc.type = syntaxClass; + serialize(serializer, desc.operands); - void decodeValue(uint8_t& value, Decoder& decoder) - { - value = uint8_t(decoder.decode<UInt32>()); - } + desc.init(); - void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) - { - decodeEnum(value, decoder); + auto node = _astBuilder->_getOrCreateImpl(std::move(desc)); + outNode = node; } - void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); - } - void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + auto node = syntaxClass.createInstance(_astBuilder); + outNode = node; - - void decodeValue(SourceLoc& value, Decoder& decoder) - { - if (!decoder.decodeNull()) - { - SerialSourceLocData::SourceLoc intermediate; - decoder.decode(intermediate); - - if (_sourceLocReader) - { - auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); - value = sourceLoc; - } - } + deferSerializeObjectContents(serializer, node); } +} - template<typename T> - void decodeValue(T*& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - decodePtr(ptr, decoder, (T*)nullptr); - } +// +// AST{Encoding|Decoding}Context::handleASTNodeContents() +// - template<typename T> - void decodeValue(RefPtr<T>& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - { - // Hi Future Tess, - // - // The next step here is decoding logic for `WitnessTable`s. - // +void ASTEncodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); +} - decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); - } - } +void ASTDecodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); - void decodeValue(Modifiers& modifiers, Decoder& decoder) - { - Modifier** link = &modifiers.first; + _cleanUpASTNode(node); +} - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Modifier* modifier = nullptr; - decode(modifier, decoder); +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - *link = modifier; - link = &modifier->next; - } - } +void ASTEncodingContext::_writeImportedModule(ModuleDecl* moduleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedModule); + auto moduleName = moduleDecl->getName(); - template<typename T, int N> - void decodeValue(ShortList<T, N>& array, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } - } + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, moduleName); +} +ModuleDecl* ASTDecodingContext::_readImportedModule() +{ + ASTSerializer serializer(this); - template<typename T> - void decode(List<T>& array, Decoder& decoder) + Name* moduleName = nullptr; + serialize(serializer, moduleName); + auto module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } + SLANG_ABORT_COMPILATION("failed to load an imported module during AST deserialization"); } + return module->getModuleDecl(); +} - template<typename T, size_t N> - void decode(T (&array)[N], Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - for (auto& element : array) - { - decode(element, decoder); - } - } +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - template<typename K, typename V> - void decode(OrderedDictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); +void ASTEncodingContext::_writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedDecl); + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - K key; - V value; - decode(key, decoder); - decode(value, decoder); + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); +} - dictionary.add(key, value); - } - } +NodeBase* ASTDecodingContext::_readImportedDecl() +{ + ASTSerializer serializer(this); - template<typename K, typename V> - void decode(Dictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); + ModuleDecl* importedFromModuleDecl = nullptr; + String mangledName; - K key; - V value; - decode(key, decoder); - decode(value, decoder); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); - dictionary.add(key, value); - } + auto importedFromModule = importedFromModuleDecl->module; + if (!importedFromModule) + { + return nullptr; } - template<typename T> - void decode(T& outValue, Decoder& decoder) + auto importedDecl = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + if (!importedDecl) { - decodeValue(outValue, decoder); + SLANG_ABORT_COMPILATION( + "failed to load an imported declaration during AST deserialization"); } + return importedDecl; +} -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _decodeDataOf($T* obj, Decoder& decoder) - { -% if T.directSuperClass then - _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); -% end -% for _,f in ipairs(T.directFields) do - decode(obj->$f, decoder); -% end - } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 1 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END -}; +// +// {write|read}SerializedModuleAST() +// + +void writeSerializedModuleAST( + RIFF::BuildCursor& cursor, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) +{ + // TODO: we might want to have a more careful pass here, + // where we only encode the public declarations. + + ASTEncodingContext context(cursor, moduleDecl, sourceLocWriter); + serialize(ASTSerializer(&context), moduleDecl); +} ModuleDecl* readSerializedModuleAST( Linkage* linkage, @@ -1546,9 +919,10 @@ ModuleDecl* readSerializedModuleAST( { ASTDecodingContext context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); - context.decodeAll(); - auto node = context.getDeclByID(0); - auto moduleDecl = as<ModuleDecl>(node); + + ModuleDecl* moduleDecl = nullptr; + serialize(ASTSerializer(&context), moduleDecl); return moduleDecl; } + } // namespace Slang |
