diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/slang-riff.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ast-base.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast.cpp | 1932 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-serialize-container.cpp | 149 | ||||
| -rw-r--r-- | source/slang/slang-serialize-riff.cpp | 897 | ||||
| -rw-r--r-- | source/slang/slang-serialize-riff.h | 431 | ||||
| -rw-r--r-- | source/slang/slang-serialize.h | 1308 |
11 files changed, 2929 insertions, 1816 deletions
diff --git a/source/core/slang-riff.h b/source/core/slang-riff.h index 9c533aeb8..0f820c747 100644 --- a/source/core/slang-riff.h +++ b/source/core/slang-riff.h @@ -1007,11 +1007,15 @@ private: ChunkBuilder* _currentChunk = nullptr; }; -#define SLANG_SCOPED_RIFF_BUILDER_DATA_CHUNK(CURSOR, TYPE) \ - ::Slang::RIFF::BuildCursor::ScopedDataChunk _scopedRIFFBuilderDataChunk(CURSOR, TYPE) - -#define SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(CURSOR, TYPE) \ - ::Slang::RIFF::BuildCursor::ScopedListChunk _scopedRIFFBuilderListChunk(CURSOR, TYPE) +#define SLANG_SCOPED_RIFF_BUILDER_DATA_CHUNK(CURSOR, TYPE) \ + ::Slang::RIFF::BuildCursor::ScopedDataChunk SLANG_CONCAT( \ + _scopedRIFFBuilderDataChunk, \ + __LINE__)(CURSOR, TYPE) + +#define SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(CURSOR, TYPE) \ + ::Slang::RIFF::BuildCursor::ScopedListChunk SLANG_CONCAT( \ + _scopedRIFFBuilderListChunk, \ + __LINE__)(CURSOR, TYPE) } // namespace RIFF diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 5affcb756..47ebc8a9b 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -41,7 +41,7 @@ class NodeBase /// The type of the node. ASTNodeType(-1) is an invalid node type, and shouldn't appear on any /// correctly constructed (through ASTBuilder) NodeBase derived class. /// The actual type is set when constructed on the ASTBuilder. - FIDDLE() ASTNodeType astNodeType = ASTNodeType(-1); + ASTNodeType astNodeType = ASTNodeType(-1); #ifdef _DEBUG SLANG_UNREFLECTED int32_t _debugUID = 0; @@ -752,7 +752,7 @@ public: FIDDLE() NameLoc nameAndLoc; FIDDLE() CapabilitySet inferredCapabilityRequirements; - FIDDLE() RefPtr<MarkupEntry> markup; + RefPtr<MarkupEntry> markup; Name* getName() const { return nameAndLoc.name; } SourceLoc getNameLoc() const { return nameAndLoc.loc; } diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 6fb281247..46d9147a0 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -758,7 +758,7 @@ class DerivativeRequirementDecl : public FunctionDeclBase // A reference to a synthesized decl representing a differentiable function requirement, this decl // will be a child in the orignal function. FIDDLE() -class DerivativeRequirementReferenceDecl : public FunctionDeclBase +class DerivativeRequirementReferenceDecl : public Decl { FIDDLE(...) FIDDLE() DerivativeRequirementDecl* referencedDecl; diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 24b10344d..7f6f7796c 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -775,6 +775,8 @@ struct ASTDumpAccess %for _,T in ipairs(Slang.NodeBase.subclasses) do static void dump_($T * node, ASTDumpContext & context) { + SLANG_UNUSED(node); + SLANG_UNUSED(context); % if T.directSuperClass then dump_(static_cast<$(T.directSuperClass)*>(node), context); % end diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index b5ebe1884..6f82a534a 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1564,7 +1564,7 @@ FIDDLE() namespace Slang Flavor m_flavor; DeclRef<Decl> m_declRef; - RefPtr<RefObject> m_obj; + RefPtr<WitnessTable> m_obj; Val* m_val = nullptr; }; diff --git a/source/slang/slang-serialize-ast.cpp b/source/slang/slang-serialize-ast.cpp index 1ef532ad1..9a61dbc5a 100644 --- a/source/slang/slang-serialize-ast.cpp +++ b/source/slang/slang-serialize-ast.cpp @@ -5,6 +5,7 @@ #include "slang-compiler.h" #include "slang-diagnostics.h" #include "slang-mangle.h" +#include "slang-serialize-riff.h" namespace Slang { @@ -13,654 +14,512 @@ namespace Slang // NodeBase* parseSimpleSyntax(Parser* parser, void* userData); +// +// Many of the types used in the AST can be serialized using +// just the `Serializer` type, so we will handle all of those first. +// -struct ASTEncodingContext +void serialize(Serializer const& serializer, ASTNodeType& value) { -private: - Encoder* encoder; - struct UnhandledCase - { - }; - - typedef Int DeclID; - Dictionary<Decl*, DeclID> mapDeclToID; - List<Decl*> decls; - - struct ImportedDeclInfo - { - Int moduleIndex = -1; - Decl* decl; - }; - List<ImportedDeclInfo> importedDecls; - - typedef Int ValID; - Dictionary<Val*, ValID> mapValToID; - List<Val*> vals; - - ModuleDecl* _module = nullptr; - - SerialSourceLocWriter* _sourceLocWriter = nullptr; + serializeEnum(serializer, value); +} -public: - ASTEncodingContext(Encoder* encoder, ModuleDecl* module, SerialSourceLocWriter* sourceLocWriter) - : encoder(encoder), _module(module), _sourceLocWriter(sourceLocWriter) - { - } +void serialize(Serializer const& serializer, TypeTag& value) +{ + serializeEnum(serializer, value); +} - template<typename T> - void encodeASTNodeContent(T* node) - { - Encoder::WithObject withObject(encoder); +void serialize(Serializer const& serializer, BaseType& value) +{ + serializeEnum(serializer, value); +} - ASTNodeDispatcher<T, void>::dispatch(node, [&](auto n) { _encodeDataOf(n); }); - } +void serialize(Serializer const& serializer, TryClauseType& value) +{ + serializeEnum(serializer, value); +} - void flush() - { - auto containerChunk = encoder->getRIFFChunk(); +void serialize(Serializer const& serializer, DeclVisibility& value) +{ + serializeEnum(serializer, value); +} - RIFF::ChunkBuilder* declChunk = nullptr; - RIFF::ChunkBuilder* importedDeclChunk = nullptr; - RIFF::ChunkBuilder* valChunk = nullptr; - { - Encoder::WithArray withList(encoder); - declChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - importedDeclChunk = encoder->getRIFFChunk(); - } - { - Encoder::WithArray withList(encoder); - valChunk = encoder->getRIFFChunk(); - } - Int declIndex = 0; - Int importedDeclIndex = 0; - Int valIndex = 0; +void serialize(Serializer const& serializer, BuiltinRequirementKind& value) +{ + serializeEnum(serializer, value); +} - bool done = false; - do - { - done = true; - while (declIndex < decls.getCount()) - { - done = false; - encoder->setRIFFChunk(declChunk); - encodeASTNodeContent(decls[declIndex++]); - } - while (importedDeclIndex < importedDecls.getCount()) - { - done = false; - encoder->setRIFFChunk(importedDeclChunk); - encodeImportedDecl(importedDecls[importedDeclIndex++]); - } - while (valIndex < vals.getCount()) - { - done = false; - encoder->setRIFFChunk(valChunk); - encodeASTNodeContent(vals[valIndex++]); - } - } while (!done); +void serialize(Serializer const& serializer, ImageFormat& value) +{ + serializeEnum(serializer, value); +} - encoder->setRIFFChunk(containerChunk); - } +void serialize(Serializer const& serializer, PreferRecomputeAttribute::SideEffectBehavior& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleForDecl(Decl* decl) - { - for (auto d = decl; d; d = d->parentDecl) - { - if (auto m = as<ModuleDecl>(d)) - return m; - } - return nullptr; - } +void serialize(Serializer const& serializer, TreatAsDifferentiableExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - ModuleDecl* findModuleDeclWasImportedFrom(Decl* decl) - { - auto declModule = findModuleForDecl(decl); - if (declModule == nullptr) - return nullptr; - if (declModule == _module) - return nullptr; - return declModule; - } +void serialize(Serializer const& serializer, LogicOperatorShortCircuitExpr::Flavor& value) +{ + serializeEnum(serializer, value); +} - DeclID getDeclID(Decl* decl) - { - SLANG_ASSERT(decl != nullptr); +void serialize(Serializer const& serializer, RequirementWitness::Flavor& value) +{ + serializeEnum(serializer, value); +} - if (auto found = mapDeclToID.tryGetValue(decl)) - return *found; +void serialize(Serializer const& serializer, CapabilityAtom& value) +{ + serializeEnum(serializer, value); +} - // We need to detect whether the declaration is an - // imported one, or one from this module itself. - // - // Imported declarations need to be handled very - // differently, since they'll involve resolving - // references to those other modules, and the - // declarations within them. - // - if (auto importedFromModule = findModuleDeclWasImportedFrom(decl)) - { - DeclID importedFromModuleDeclID = 0; - if (decl != importedFromModule) - { - importedFromModuleDeclID = getDeclID(importedFromModule); - } +void serialize(Serializer const& serializer, DeclAssociationKind& value) +{ + serializeEnum(serializer, value); +} - DeclID id = ~importedDecls.getCount(); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, TokenType& value) +{ + serializeEnum(serializer, value); +} - ImportedDeclInfo info; - info.moduleIndex = ~importedFromModuleDeclID; - info.decl = decl; - importedDecls.add(info); +void serialize(Serializer const& serializer, ValNodeOperandKind& value) +{ + serializeEnum(serializer, value); +} - return id; - } - else - { - DeclID id = decls.getCount(); - decls.add(decl); - mapDeclToID.add(decl, id); +void serialize(Serializer const& serializer, SPIRVAsmOperand::Flavor& value) +{ + serializeEnum(serializer, value); +} - return id; - } - } +void serialize(Serializer const& serializer, MatrixCoord& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.row); + serialize(serializer, value.col); +} - void encodePtr(Decl* decl) +void serializePtr(Serializer const& serializer, DiagnosticInfo const*& value, DiagnosticInfo const*) +{ + Int32 id = 0; + if (isWriting(serializer)) { - DeclID id = getDeclID(decl); - encoder->encode(id); + id = value->id; + serialize(serializer, id); } - - ValID getValID(Val* val) + else { - SLANG_ASSERT(val != nullptr); - - if (auto found = mapValToID.tryGetValue(val)) - return *found; - - // In order to ensure that values can be fully constructed - // from the get-go (so that they will get cached correctly), - // we conspire to ensure that every value is preceded by - // all of its operands. - // - for (auto operand : val->m_operands) - { - switch (operand.kind) - { - default: - break; - - case ValNodeOperandKind::ValNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Val>(operandNode)); - getValID(static_cast<Val*>(operandNode)); - } - break; - - case ValNodeOperandKind::ASTNode: - if (auto operandNode = operand.values.nodeOperand) - { - SLANG_ASSERT(as<Decl>(operandNode)); - getDeclID(static_cast<Decl*>(operandNode)); - } - break; - } - } - auto resolved = val->resolve(); - if (resolved != val) - { - getValID(resolved); - } - - ValID id = vals.getCount(); - vals.add(val); - mapValToID.add(val, id); - return id; + serialize(serializer, id); + value = getDiagnosticsLookup()->getDiagnosticById(id); } +} - void encodePtr(Val* val) +void serialize(Serializer const& serializer, SemanticVersion& value) +{ + auto raw = value.getRawValue(); + serialize(serializer, raw); + value = SemanticVersion::fromRaw(raw); +} + +void serialize(Serializer const& serializer, SyntaxClass<NodeBase>& value) +{ + ASTNodeType raw; + if (isWriting(serializer)) { - ValID id = getValID(val); - encoder->encode(id); + raw = value.getTag(); } - - void encodeImportedDecl(ImportedDeclInfo const& info) + serialize(serializer, raw); + if (isReading(serializer)) { - Encoder::WithKeyValuePair withPair(encoder); - encode(info.moduleIndex); - auto decl = info.decl; - if (auto importedModuleDecl = as<ModuleDecl>(decl)) - { - SLANG_ASSERT(info.moduleIndex == -1); - encode(importedModuleDecl->getName()); - } - else - { - auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - encode(mangledName); - } + value = SyntaxClass<NodeBase>(raw); } +} - void encodePtr(Modifier* modifier) { encodeASTNodeContent(modifier); } - void encodePtr(Expr* expr) { encodeASTNodeContent(expr); } - void encodePtr(Stmt* stmt) { encodeASTNodeContent(stmt); } +// +// Many types in the AST need additional context (beyond +// what the `Serializer` has) in order to serialize +// themselves or their members. +// +// We define a custom serializer interface to capture +// the cases that can't be handled by a `Serializer` +// alone. +// - void encodePtr(Name* name) { encode(name->text); } +/// Interface for AST serialization +struct ASTSerializerImpl +{ +public: + virtual void handleASTNode(NodeBase*& value) = 0; + virtual void handleASTNodeContents(NodeBase* value) = 0; + virtual void handleName(Name*& value) = 0; + virtual void handleSourceLoc(SourceLoc& value) = 0; + virtual void handleToken(Token& value) = 0; - void encodePtr(MarkupEntry* entry) - { - // TODO: is this case needed? - SLANG_UNUSED(entry); - } + // Note that this type does *not* inherit from `ISerializerImpl`. + // + // We want to decouple the AST-specific context information + // from the lower-level details of the serialization format. + // + // Instead of using inheritance, we expect that any + // `ASTSerializerImpl` will aggregate a lower-level + // serializer, and the interface exposes access to + // that base serializer implementation. - void encodePtr(DeclAssociationList* list) - { - // We serialize this as if it were a simple list - // of key-value pairs because... well... that's - // what it amounts to in practice. - // - Encoder::WithArray withArray(encoder); - for (auto association : list->associations) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(association->kind); - encode(association->decl); - } - } + virtual ISerializerImpl* getBaseSerializer() = 0; +}; - void encodePtr(CandidateExtensionList* list) { encode(list->candidateExtensions); } +/// Specialization of `Serializer_` for AST serialization. +template<> +struct Serializer_<ASTSerializerImpl> : SerializerBase<ASTSerializerImpl> +{ +public: + using SerializerBase::SerializerBase; - void encodePtr(WitnessTable* witnessTable) - { - Encoder::WithObject withObject(encoder); - encode(witnessTable->baseType); - encode(witnessTable->witnessedType); - encode(witnessTable->isExtern); - - // TODO(tfoley): In theory we should be able to streamline - // this so that we only encode the requirements that we - // absolutely need to (which basically amounts to `associatedtype` - // requirements where the satisfying type is part of the public - // API of the type). - // - encode(witnessTable->m_requirementDictionary); - } + // + // In order to allow an `ASTSerializer` to be used with + // functions that expect an ordinary `Serializer`, we + // implement an implicit conversion operator. + // - void encodeValue(RequirementWitness const& witness) - { - Encoder::WithKeyValuePair withPair(encoder); - encodeEnum(witness.m_flavor); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; + operator Serializer() const { return Serializer(get()->getBaseSerializer()); } +}; - case RequirementWitness::Flavor::declRef: - encode(witness.m_declRef); - break; +/// Context type for AST serialization. +using ASTSerializer = Serializer_<ASTSerializerImpl>; - case RequirementWitness::Flavor::val: - encode(witness.m_val); - break; +template<typename T> +void serializeObject(ASTSerializer const& serializer, T*& value, NodeBase*) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serializer->handleASTNode(*(NodeBase**)&value); +} - case RequirementWitness::Flavor::witnessTable: - encode((WitnessTable*)witness.m_obj.Ptr()); - break; - } - } +void serializeObjectContents(ASTSerializer const& serializer, NodeBase* value, NodeBase*) +{ + serializer->handleASTNodeContents(value); +} - void encodePtr(DiagnosticInfo* info) { encode(Int(info->id)); } +template<typename T> +void serialize(ASTSerializer const& serializer, DeclRef<T>& value) +{ + serialize(serializer, value.declRefBase); +} - void encodePtr(DeclBase* declBase) +void serialize(ASTSerializer const& serializer, SourceLoc& value) +{ + serializer->handleSourceLoc(value); +} + +void serialize(ASTSerializer const& serializer, RequirementWitness& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.m_flavor); + switch (value.m_flavor) { - if (auto decl = as<Decl>(declBase)) - { - encodePtr(decl); - } - else - { - encodeASTNodeContent(declBase); - } - } + case RequirementWitness::Flavor::none: + break; - void encodeValue(UnhandledCase); + case RequirementWitness::Flavor::declRef: + serialize(serializer, value.m_declRef); + break; - void encodeValue(String const& value) { encoder->encode(value); } + case RequirementWitness::Flavor::val: + serialize(serializer, value.m_val); + break; - void encodeValue(Token const& value) - { - encode(value.type); - encode(TokenFlags(value.flags & ~TokenFlag::Name)); - encode(value.loc); - if (value.hasContent()) - encoder->encodeString(value.getContent()); - else - encode(nullptr); + case RequirementWitness::Flavor::witnessTable: + serialize(serializer, value.m_obj); + break; } +} - void encodeValue(NameLoc const& value) { encode(value.name); } - - void encodeValue(SemanticVersion value) { encoder->encode(value.getRawValue()); } +void serialize(ASTSerializer const& serializer, WitnessTable& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.baseType); + serialize(serializer, value.witnessedType); + serialize(serializer, value.isExtern); + + // TODO(tfoley): In theory we should be able to streamline + // this so that we only encode the requirements that we + // absolutely need to (which basically amounts to `associatedtype` + // requirements where the satisfying type is part of the public + // API of the type). + // + serialize(serializer, value.m_requirementDictionary); +} - void encodeValue(CapabilitySet const& value) +void serialize(Serializer const& serializer, CapabilityAtomSet& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - // While the `CapabilityTargetSets` type is a dictionary, - // in practice each entry already embeds its own key - // (the target atom), so we can encode this as just - // an array of the `CapabilityTargetSet` values. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.getCapabilityTargetSets()) + for (auto rawAtom : value) { - encode(pair.second); + auto atom = CapabilityAtom(rawAtom); + serialize(serializer, atom); } } - - void encodeValue(CapabilityTargetSet const& value) + else { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.target); - - // Similar to the case for the `CapabilityTargetSets` above, - // each `CapabilityStageSet` already includes the stage atom, - // so we can simply encode the values from the dictionary. - // - Encoder::WithArray withArray(encoder); - for (auto pair : value.shaderStageSets) + while (hasElements(serializer)) { - encode(pair.second); + CapabilityAtom atom; + serialize(serializer, atom); + value.add(UInt(atom)); } } +} - void encodeValue(CapabilityStageSet const& value) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(value.stage); - encode(value.atomSet); - } +void serialize(Serializer const& serializer, CapabilityStageSet& value) +{ + serialize(serializer, value.atomSet); +} - void encodeValue(CapabilityAtomSet const& value) +void serialize(Serializer const& serializer, CapabilityTargetSet& value) +{ + serialize(serializer, value.shaderStageSets); + + // The value for each entry in `shaderStageSets` have + // a `stage` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityStageSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - Encoder::WithArray withArray(encoder); - for (auto rawAtom : value) - { - encode(CapabilityAtom(rawAtom)); - } + for (auto& p : value.shaderStageSets) + p.second.stage = p.first; } +} - template<typename T> - void encodeValue(std::optional<T> const& value) +void serialize(Serializer const& serializer, CapabilitySet& value) +{ + serialize(serializer, value.getCapabilityTargetSets()); + + // The value for each entry in `getCapabilityTargetSets()` have + // a `target` field that is redundant with the key for + // that entry. Rather than serialize the key as part + // of the `CapabilityTargetSet` type, we instead copy + // it over from the key to the value in the case where + // we are reading. + // + if (isReading(serializer)) { - if (value) - encodeValue(*value); - else - encoder->encode(nullptr); + for (auto& p : value.getCapabilityTargetSets()) + p.second.target = p.first; } +} + +void serialize(ASTSerializer const& serializer, CandidateExtensionList& value) +{ + serialize(serializer, value.candidateExtensions); +} + +void serialize(ASTSerializer const& serializer, DeclAssociation& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.kind); + serialize(serializer, value.decl); +} - void encodeValue(SyntaxClass<NodeBase> const& value) { encode(value.getTag()); } +void serialize(ASTSerializer const& serializer, DeclAssociationList& value) +{ + serialize(serializer, value.associations); +} - template<typename T> - void encodeValue(DeclRef<T> const& value) +void serialize(ASTSerializer const& serializer, Modifiers& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - encode((DeclRefBase*)value); + for (auto modifier : value) + { + serialize(serializer, modifier); + } } - - void encodeValue(ValNodeOperand value) + else { - Encoder::WithKeyValuePair withPair(encoder); + Modifier** link = &value.first; - encodeEnum(value.kind); - switch (value.kind) + while (hasElements(serializer)) { - case ValNodeOperandKind::ConstantValue: - encode(value.values.intOperand); - break; - - case ValNodeOperandKind::ValNode: - encode(static_cast<Val*>(value.values.nodeOperand)); - break; + Modifier* modifier = nullptr; + serialize(serializer, modifier); - case ValNodeOperandKind::ASTNode: - { - if (auto decl = as<Decl>(value.values.nodeOperand)) - { - encode(decl); - } - else - { - SLANG_UNEXPECTED("AST node operand of `Val` was expected to be a `Decl`"); - } - } - break; + *link = modifier; + link = &modifier->next; } } +} - void encodeValue(TypeExp value) { encode(value.type); } +void serialize(ASTSerializer const& serializer, TypeExp& value) +{ + serialize(serializer, value.type); +} - void encodeValue(QualType value) - { - Encoder::WithObject withObject(encoder); - encode(value.type); - encode(value.isLeftValue); - encode(value.hasReadOnlyOnTarget); - encode(value.isWriteOnly); - } +void serialize(ASTSerializer const& serializer, QualType& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.isLeftValue); + serialize(serializer, value.hasReadOnlyOnTarget); + serialize(serializer, value.isWriteOnly); +} - void encodeValue(MatrixCoord value) - { - Encoder::WithObject withObject(encoder); - encode(value.row); - encode(value.col); - } +void serialize(ASTSerializer const& serializer, Token& value) +{ + serializer->handleToken(value); +} - void encodeValue(SPIRVAsmOperand::Flavor const& value) { encodeEnum(value); } +void serialize(ASTSerializer const& serializer, SPIRVAsmOperand& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.flavor); + serialize(serializer, value.token); + serialize(serializer, value.expr); + serialize(serializer, value.bitwiseOrWith); + serialize(serializer, value.knownValue); + serialize(serializer, value.wrapInId); + serialize(serializer, value.type); +} - void encodeValue(SPIRVAsmOperand const& value) - { - Encoder::WithObject withObject(encoder); - encode(value.flavor); - encode(value.token); - encode(value.expr); - encode(value.bitwiseOrWith); - encode(value.knownValue); - encode(value.wrapInId); - encode(value.type); - } +void serialize(ASTSerializer const& serializer, SPIRVAsmInst& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.opcode); + serialize(serializer, value.operands); +} - void encodeValue(SPIRVAsmInst const& value) +void serialize(ASTSerializer const& serializer, ValNodeOperand& value) +{ + SLANG_SCOPED_SERIALIZER_TAGGED_UNION(serializer); + serialize(serializer, value.kind); + switch (value.kind) { - Encoder::WithObject withObject(encoder); - encode(value.opcode); - encode(value.operands); - } + case ValNodeOperandKind::ConstantValue: + serialize(serializer, value.values.intOperand); + break; - - template<typename T, typename = std::enable_if_t<std::is_same_v<T, bool>>> - void encodeValue(T value) - { - encoder->encodeBool(value); + case ValNodeOperandKind::ValNode: + case ValNodeOperandKind::ASTNode: + serialize(serializer, value.values.nodeOperand); + break; } +} - void encodeValue(Int32 value) { encoder->encode(value); } - void encodeValue(UInt32 value) { encoder->encode(value); } - void encodeValue(Int64 value) { encoder->encode(value); } - void encodeValue(UInt64 value) { encoder->encode(value); } - void encodeValue(float value) { encoder->encode(value); } - void encodeValue(double value) { encoder->encode(value); } - - void encodeValue(uint8_t value) { encoder->encode(UInt32(value)); } - - void encodeValue(std::nullptr_t) { encoder->encode(nullptr); } +void serializeObject(ASTSerializer const& serializer, Name*& value, Name*) +{ + serializer->handleName(value); +} - template<typename T> - void encodeEnum(T value) - { - encoder->encode(Int32(value)); - } +void serialize(ASTSerializer const& serializer, NameLoc& value) +{ + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.name); + serialize(serializer, value.loc); +} - void encodeValue(DeclVisibility value) { encodeEnum(value); } - void encodeValue(BaseType value) { encodeEnum(value); } - void encodeValue(BuiltinRequirementKind value) { encodeEnum(value); } - void encodeValue(ASTNodeType value) { encodeEnum(value); } - void encodeValue(ImageFormat value) { encodeEnum(value); } - void encodeValue(TypeTag value) { encodeEnum(value); } - void encodeValue(TryClauseType value) { encodeEnum(value); } - void encodeValue(CapabilityAtom value) { encodeEnum(value); } - void encodeValue(DeclAssociationKind value) { encodeEnum(value); } - void encodeValue(TokenType value) { encodeEnum(value); } - - void encodeValue(SourceLoc value) - { - if (!_sourceLocWriter) - { - encoder->encode(nullptr); - } - else - { - auto intermediate = _sourceLocWriter->addSourceLoc(value); - encoder->encode(intermediate); +#if 0 // FIDDLE TEMPLATE: +%for _,T in ipairs(Slang.NodeBase.subclasses) do +void _serializeASTNodeContents(ASTSerializer const& serializer, $T* value) +{ + SLANG_UNUSED(serializer); + SLANG_UNUSED(value); +% if T.directSuperClass then + _serializeASTNodeContents(serializer, static_cast<$(T.directSuperClass)*>(value)); +% end +% for _,f in ipairs(T.directFields) do + serialize(serializer, value->$f); +% end } - } +%end +#else // FIDDLE OUTPUT: +#define FIDDLE_GENERATED_OUTPUT_ID 0 +#include "slang-serialize-ast.cpp.fiddle" +#endif // FIDDLE END - template<typename T> - void encodeValue(T const* ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(const_cast<T*>(ptr)); - } - } +void serializeASTNodeContents(ASTSerializer const& serializer, NodeBase* node) +{ + ASTNodeDispatcher<NodeBase, void>::dispatch( + node, + [&](auto n) { _serializeASTNodeContents(serializer, n); }); +} - template<typename T> - void encodeValue(RefPtr<T> const& ptr) - { - if (!ptr) - { - encoder->encode(nullptr); - } - else - { - encodePtr(ptr.Ptr()); - } - } +enum class PseudoASTNodeType +{ + None, + ImportedModule, + ImportedDecl, +}; - void encodeValue(Modifiers const& modifiers) - { - Encoder::WithArray withArray(encoder); - for (auto m : const_cast<Modifiers&>(modifiers)) - { - encode(m); - } - } +static PseudoASTNodeType _getPseudoASTNodeType(ASTNodeType type) +{ + return int(type) < 0 ? PseudoASTNodeType(~int(type)) : PseudoASTNodeType::None; +} + +static ASTNodeType _getAsASTNodeType(PseudoASTNodeType type) +{ + return ASTNodeType(~int(type)); +} - template<typename T, int N> - void encodeValue(ShortList<T, N> const& array) +struct ASTEncodingContext : ASTSerializerImpl +{ +public: + ASTEncodingContext( + RIFF::BuildCursor& cursor, + ModuleDecl* module, + SerialSourceLocWriter* sourceLocWriter) + : _writer(cursor.getCurrentChunk()), _module(module), _sourceLocWriter(sourceLocWriter) { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } } +private: + RIFFSerialWriter _writer; + ModuleDecl* _module = nullptr; + SerialSourceLocWriter* _sourceLocWriter = nullptr; - template<typename T> - void encode(List<T> const& array) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual ISerializerImpl* getBaseSerializer() override { return &_writer; } - template<typename T, size_t N> - void encode(T const (&array)[N]) - { - Encoder::WithArray withArray(encoder); - for (auto element : array) - { - encode(element); - } - } + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& node) override; + virtual void handleASTNodeContents(NodeBase* node) override; - template<typename K, typename V> - void encode(OrderedDictionary<K, V> const& dictionary) - { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) - { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.key); - encode(p.value); - } - } + void _writeImportedModule(ModuleDecl* moduleDecl); + void _writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl); - template<typename K, typename V> - void encode(Dictionary<K, V> const& dictionary) + ModuleDecl* _findModuleForDecl(Decl* decl) { - Encoder::WithArray withArray(encoder); - for (auto p : dictionary) + for (auto d = decl; d; d = d->parentDecl) { - Encoder::WithKeyValuePair withPair(encoder); - encode(p.first); - encode(p.second); + if (auto m = as<ModuleDecl>(d)) + return m; } + return nullptr; } - template<typename T> - void encode(T const& value) + ModuleDecl* _findModuleDeclWasImportedFrom(Decl* decl) { - encodeValue(value); - } - - // for each class of node, we generate - // code to recursively serialize each - // of its fields. - -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _encodeDataOf($T* obj) - { -%if T.directSuperClass then - _encodeDataOf(static_cast<$(T.directSuperClass)*>(obj)); -%end -%for _,f in ipairs(T.directFields) do - encode(obj->$f); -%end + auto declModule = _findModuleForDecl(decl); + if (declModule == nullptr) + return nullptr; + if (declModule == _module) + return nullptr; + return declModule; } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 0 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END }; -void writeSerializedModuleAST( - Encoder* encoder, - ModuleDecl* moduleDecl, - SerialSourceLocWriter* sourceLocWriter) -{ - Encoder::WithObject withObject(encoder); - - // TODO: we should have a more careful pass here, - // where we only encode the public declarations - // - - ASTEncodingContext context(encoder, moduleDecl, sourceLocWriter); - context.getDeclID(moduleDecl); - context.flush(); -} - -struct ASTDecodingContext +struct ASTDecodingContext : ASTSerializerImpl { public: ASTDecodingContext( @@ -673,284 +532,60 @@ public: : _linkage(linkage) , _astBuilder(astBuilder) , _sink(sink) - , _baseChunk(as<RIFF::ListChunk>(baseChunk)) , _sourceLocReader(sourceLocReader) , _requestingSourceLoc(requestingSourceLoc) + , _riffReader(baseChunk) { } +private: Linkage* _linkage = nullptr; + ASTBuilder* _astBuilder = nullptr; DiagnosticSink* _sink = nullptr; SerialSourceLocReader* _sourceLocReader = nullptr; SourceLoc _requestingSourceLoc; + RIFFSerialReader _riffReader; - SlangResult decodeAll() - { - auto cursor = _baseChunk->getChildren().begin(); - - // There are a few different top-level chunks that - // hold different arrays that we need in order - // to decode the entire module hierarchy. - // - // Basically, these lists correspond to the kinds - // of nodes in the AST hierarchy for which back-references - // are allowed (all other nodes should, barring - // weird corner cases, form a single tree-structured - // ownership hierarchy, rooted at the `ModuleDecl`. - // - - // First there is the list that actually encodes - // for the declarations in the module, including - // the `ModuleDecl` itself, which should be the - // first entry in the list. - // - auto declChunk = *cursor; - ++cursor; - - // Next there is a list of all the declarations - // referenced inside of the module that need to - // be imported in from outside. - // - auto importedDeclChunk = *cursor; - ++cursor; + virtual ISerializerImpl* getBaseSerializer() override { return &_riffReader; } - // Then there are all the `Val`-derived nodes that - // are needed by the module, which will need to be - // deduplicated so that they are unique within the - // current compilation context. - // - auto valChunk = *cursor; - ++cursor; + virtual void handleName(Name*& value) override; + virtual void handleSourceLoc(SourceLoc& value) override; + virtual void handleToken(Token& value) override; + virtual void handleASTNode(NodeBase*& outNode) override; + virtual void handleASTNodeContents(NodeBase* node) override; - // The process of decoding the module is then spread - // over a number of steps. - // - // The first step is to process all of the imported - // declarations, so that other nodes can refer to - // them. - // - SLANG_RETURN_ON_FAIL(decodeImportedDecls(importedDeclChunk)); - - // Next we process the declarations that are within - // the module itself, first creating an "empty shell" - // of each declaration that has the right size in - // memory (and the right `ASTNodeType` tag), so that - // we can wire up references to it (including circular - // references)... so long as nothing here tries to - // look *inside* the empty shell along the way. - // - SLANG_RETURN_ON_FAIL(createEmptyShells(declChunk)); - - // Once all the `Decl`s that might be needed have - // been allocated, we can process all the `Val`s - // that might reference those`Decl`s (and one another). - // - // The nature of the `Val` representation ensures - // that there cannot be cirularities in the references - // between `Val`s, and the encoding process will have - // sorted the entries so that a `Val` only ever appears - // *after* its operands. - // - SLANG_RETURN_ON_FAIL(decodeVals(valChunk)); + ModuleDecl* _readImportedModule(); + NodeBase* _readImportedDecl(); - // Once all the back-reference-able objects have been - // instantiated in memory, we can go back through the - // `Decl`s in the module and fill in those empty shells. - // - SLANG_RETURN_ON_FAIL(fillEmptyShells(declChunk)); - - // As a final pass, we perform any special cleanup actions - // that might be required to make the output valid for consumers. - // - // For example, this is where we set the `DeclCheckState` of everything - // we are loading to reflect the fact that everything we deserialize - // is (supposed to be) fully cheked. - // - SLANG_RETURN_ON_FAIL(cleanUpNodes()); - - - return SLANG_OK; - } - - typedef Int DeclID; - Decl* getDeclByID(DeclID id) + void _cleanUpASTNode(NodeBase* node) { - if (id >= 0) - { - return _decls[id]; - } - else + if (auto expr = as<Expr>(node)) { - return _importedDecls[~id]; + expr->checked = true; } - } - -private: - struct UnhandledCase - { - }; - - ASTBuilder* _astBuilder = nullptr; - RIFF::ListChunk const* _baseChunk = nullptr; - - List<Decl*> _decls; - List<Decl*> _importedDecls; - List<Val*> _vals; - - typedef Int ValID; - Val* getValByID(ValID id) { return _vals[id]; } - - SlangResult decodeImportedDecls(RIFF::Chunk const* importedDeclChunk) - { - Decoder decoder(importedDeclChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + else if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - Int moduleIndex; - decode(moduleIndex, decoder); + decl->checkState = DeclCheckState::CapabilityChecked; - if (moduleIndex == -1) + if (auto genericDecl = as<GenericDecl>(node)) { - Name* moduleName = nullptr; - decode(moduleName, decoder); - - Decl* importedModule = getImportedModule(moduleName); - _importedDecls.add(importedModule); + _assignGenericParameterIndices(genericDecl); } - else + else if (auto syntaxDecl = as<SyntaxDecl>(node)) { - auto importedFromModuleDecl = as<ModuleDecl>(_importedDecls[moduleIndex]); - auto importedFromModule = importedFromModuleDecl->module; - - String mangledName; - decode(mangledName, decoder); - - auto importedNode = - importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); - auto importedDecl = as<Decl>(importedNode); - _importedDecls.add(importedDecl); + syntaxDecl->parseCallback = &parseSimpleSyntax; + syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); } - } - return SLANG_OK; - } - - ModuleDecl* getImportedModule(Name* moduleName) - { - Module* module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); - if (!module) - { - SLANG_ABORT_COMPILATION("failed to load an imported module during deserialization"); - } - - return module->getModuleDecl(); - } - - SlangResult decodeVals(RIFF::Chunk const* valChunk) - { - Decoder decoder(valChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Val* val = decodeValNode(decoder); - _vals.add(val); - } - return SLANG_OK; - } - - SlangResult createEmptyShells(RIFF::Chunk const* declChunk) - { - Decoder decoder(declChunk); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ASTNodeType nodeType; - - // Each of the declarations is expected to take - // the form of an object with a first field - // that holds the node type. - // + else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); + auto declScope = _astBuilder->create<Scope>(); + declScope->containerDecl = namespaceLikeDecl; + namespaceLikeDecl->ownedScope = declScope; } - - auto emptyShell = createEmptyShell(nodeType); - auto declEmptyShell = as<Decl>(emptyShell); - _decls.add(declEmptyShell); } - - return SLANG_OK; } - Val* decodeValNode(Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeType nodeType; - decode(nodeType, decoder); - - ValNodeDesc desc; - desc.type = SyntaxClass<NodeBase>(nodeType); - - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - ValNodeOperand operand; - decode(operand, decoder); - desc.operands.add(operand); - } - - desc.init(); - - auto val = _astBuilder->_getOrCreateImpl(_Move(desc)); - - // Values created during deserialization are - // not expected to ever resolve further, because - // they should be coming from fully checked code. - // - // val->resolve(); - // val->_setUnique(); - - return val; - } - - NodeBase* createEmptyShell(ASTNodeType nodeType) - { - return SyntaxClass<NodeBase>(nodeType).createInstance(_astBuilder); - } - - SlangResult fillEmptyShells(RIFF::Chunk const* declChunk) - { - Index declIndex = 0; - - Decoder decoder(declChunk); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - auto declEmptyShell = _decls[declIndex++]; - decodeASTNodeContent(declEmptyShell, decoder); - } - - return SLANG_OK; - } - - SlangResult cleanUpNodes() - { - for (auto decl : _decls) - { - decl->checkState = DeclCheckState::CapabilityChecked; - } - - return SLANG_OK; - } - - - void assignGenericParameterIndices(GenericDecl* genericDecl) + void _assignGenericParameterIndices(GenericDecl* genericDecl) { int parameterCounter = 0; for (auto m : genericDecl->members) @@ -965,576 +600,314 @@ private: } } } +}; +// +// We are matching up the corresponding `handle*()` operations from the +// `AST{Encoding|Decoding}Context` types here, so that it is easier +// to visually verify that they are serializing the same data with the +// same ordering. +// - void cleanUpASTNode(NodeBase* node) - { - if (auto expr = as<Expr>(node)) - { - expr->checked = true; - } - else if (auto genericDecl = as<GenericDecl>(node)) - { - assignGenericParameterIndices(genericDecl); - } - else if (auto syntaxDecl = as<SyntaxDecl>(node)) - { - syntaxDecl->parseCallback = &parseSimpleSyntax; - syntaxDecl->parseUserData = (void*)syntaxDecl->syntaxClass.getInfo(); - } - else if (auto namespaceLikeDecl = as<NamespaceDeclBase>(node)) - { - auto declScope = _astBuilder->create<Scope>(); - declScope->containerDecl = namespaceLikeDecl; - namespaceLikeDecl->ownedScope = declScope; - } - } - - void decodeASTNodeContent(NodeBase* node, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - - ASTNodeDispatcher<NodeBase, void>::dispatch( - node, - [&](auto n) { _decodeDataOf(n, decoder); }); +// +// AST{Encoding|Decoding}Context::handleName() +// - cleanUpASTNode(node); - } +void ASTEncodingContext::handleName(Name*& value) +{ + serialize(ASTSerializer(this), value->text); +} - DeclID decodeDeclID(Decoder& decoder) - { - DeclID result = decoder.decode<DeclID>(); - return result; - } +void ASTDecodingContext::handleName(Name*& value) +{ + String text; + serialize(ASTSerializer(this), text); + value = _astBuilder->getNamePool()->getName(text); +} - ValID decodeValID(Decoder& decoder) - { - ValID result = decoder.decode<ValID>(); - return result; - } +// +// AST{Encoding|Decoding}Context::handleSourceLoc() +// - template<typename T> - void decodeASTNode(T*& node, Decoder& decoder) +void ASTEncodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (_sourceLocWriter != nullptr) { - ASTNodeType nodeType; - auto saved = decoder.getCursor(); - { - Decoder::WithObject withObject(decoder); - decode(nodeType, decoder); - } - decoder.setCursor(saved); - - auto shell = createEmptyShell(nodeType); - decodeASTNodeContent(shell, decoder); - - node = as<T>(shell); + auto rawValue = _sourceLocWriter->addSourceLoc(value); + serialize(serializer, rawValue); } +} - void decodePtr(Name*& name, Decoder& decoder, Name*) +void ASTDecodingContext::handleSourceLoc(SourceLoc& value) +{ + ASTSerializer serializer(this); + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - String text; - decode(text, decoder); - - name = _astBuilder->getNamePool()->getName(text); - } + SerialSourceLocData::SourceLoc rawValue; + serialize(serializer, rawValue); - void decodePtr(DeclAssociationList*& outList, Decoder& decoder, DeclAssociationList*) - { - // Mirroring the encoding logic, we decode this - // as a list of key-value pairs. - // - auto list = RefPtr(new DeclAssociationList()); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) + if (_sourceLocReader) { - auto association = RefPtr(new DeclAssociation()); - - Decoder::WithKeyValuePair withPair(decoder); - decode(association->kind, decoder); - decode(association->decl, decoder); - - list->associations.add(association); + value = _sourceLocReader->getSourceLoc(rawValue); } - - outList = list.detach(); } +} - void decodePtr(DiagnosticInfo const*& info, Decoder& decoder, DiagnosticInfo const*) - { - Int id; - decode(id, decoder); - info = getDiagnosticsLookup()->getDiagnosticById(id); - } - - void decodePtr(MarkupEntry*& markupEntry, Decoder&, MarkupEntry*) - { - // TODO: is this case needed? - markupEntry = nullptr; - } - - void decodePtr(CandidateExtensionList*& list, Decoder& decoder, CandidateExtensionList*) - { - auto result = RefPtr(new CandidateExtensionList()); - decode(result->candidateExtensions, decoder); - list = result.detach(); - } - - void decodePtr(WitnessTable*& witnessTable, Decoder& decoder, WitnessTable*) - { - Decoder::WithObject withObject(decoder); - auto wt = RefPtr(new WitnessTable()); - decode(wt->baseType, decoder); - decode(wt->witnessedType, decoder); - decode(wt->isExtern, decoder); - decode(wt->m_requirementDictionary, decoder); - witnessTable = wt.detach(); - } - - void decodeValue(RequirementWitness& witness, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decodeEnum(witness.m_flavor, decoder); - switch (witness.m_flavor) - { - case RequirementWitness::Flavor::none: - break; - - case RequirementWitness::Flavor::declRef: - decode(witness.m_declRef, decoder); - break; - - case RequirementWitness::Flavor::val: - decode(witness.m_val, decoder); - break; +// +// AST{Encoding|Decoding}Context::handleToken() +// - case RequirementWitness::Flavor::witnessTable: - { - RefPtr<WitnessTable> object; - decode(object, decoder); - witness.m_obj = object; - } - break; - } - } +void ASTDecodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Val*) - { - ValID id = decodeValID(decoder); - node = static_cast<T*>(getValByID(id)); - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, Decl*) - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } + serialize(serializer, value.flags); - template<typename T> - void decodePtr(T*& node, Decoder& decoder, DeclBase*) { - // This case is a bit of a hack. We need - // to identify whether we are looking at - // an indirection to a `Decl` (which would - // be serialized as an integer `DeclID`), - // or something else derived from `DeclBase`. - // - switch (decoder.getTag()) + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (hasElements(serializer)) { - default: - decodeASTNode(node, decoder); - break; - - case SerialBinary::kInt32FourCC: - case SerialBinary::kInt64FourCC: - case SerialBinary::kUInt32FourCC: - case SerialBinary::kUInt64FourCC: - { - DeclID id = decodeDeclID(decoder); - node = static_cast<T*>(getDeclByID(id)); - } - break; - } - } - - template<typename T> - void decodePtr(T*& node, Decoder& decoder, NodeBase*) - { - decodeASTNode(node, decoder); - } + String content; + serialize(serializer, content); - - void decodeValue(UnhandledCase, Decoder& decoder); - - void decodeValue(String& value, Decoder& decoder) { value = decoder.decodeString(); } - - void decodeValue(Token& value, Decoder& decoder) - { - decode(value.type, decoder); - decode(value.flags, decoder); - decode(value.loc, decoder); - if (decoder.decodeNull()) - { - } - else - { - Name* name = nullptr; - decode(name, decoder); + // An important note here is that we cannot just + // call `value.setContent(...)` and pass in an + // `UnownedStringSlice` of `content`, because the + // `Token` will not take ownership of its own + // textual content. + // + // Instead, we need to get the text we just loaded + // into something that the `Token` can refer info, + // and the easiest way to accomplish that is to + // represent the text using a `Name`. + // + Name* name = _astBuilder->getNamePool()->getName(content); value.setName(name); } } +} - void decodeValue(NameLoc& value, Decoder& decoder) { decode(value.name, decoder); } - - void decodeValue(SemanticVersion& value, Decoder& decoder) - { - SemanticVersion::RawValue rawValue = decoder.decode<SemanticVersion::RawValue>(); - value.setRawValue(rawValue); - } - - void decodeValue(CapabilitySet& value, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityTargetSet targetSet; - decode(targetSet, decoder); - value.getCapabilityTargetSets()[targetSet.target] = targetSet; - } - } - - void decodeValue(CapabilityTargetSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.target, decoder); +void ASTEncodingContext::handleToken(Token& value) +{ + ASTSerializer serializer(this); - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityStageSet stageSet; - decode(stageSet, decoder); - value.shaderStageSets[stageSet.stage] = stageSet; - } - } + SLANG_SCOPED_SERIALIZER_STRUCT(serializer); + serialize(serializer, value.type); + serialize(serializer, value.loc); - void decodeValue(CapabilityStageSet& value, Decoder& decoder) - { - Decoder::WithKeyValuePair withPair(decoder); - decode(value.stage, decoder); - decode(value.atomSet, decoder); - } + TokenFlags flags = TokenFlags(value.flags & ~TokenFlag::Name); + serialize(serializer, flags); - void decodeValue(CapabilityAtomSet& value, Decoder& decoder) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - CapabilityAtom atom; - decode(atom, decoder); - value.add(UInt(atom)); - } - } - - template<typename T> - void decodeValue(std::optional<T>& outValue, Decoder& decoder) - { - if (decoder.decodeNull()) - { - outValue.reset(); - } - else + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (value.hasContent()) { - T value; - decode(value, decoder); - outValue = value; + String content = value.getContent(); + serialize(serializer, content); } } +} - void decodeValue(SyntaxClass<NodeBase>& syntaxClass, Decoder& decoder) - { - ASTNodeType nodeType; - decode(nodeType, decoder); - syntaxClass = SyntaxClass<NodeBase>(nodeType); - } - - template<typename T> - void decodeValue(DeclRef<T>& declRef, Decoder& decoder) - { - decode(declRef.declRefBase, decoder); - } +// +// AST{Encoding|Decoding}Context::handleASTNode() +// - void decodeValue(ValNodeOperand& value, Decoder& decoder) +void ASTEncodingContext::handleASTNode(NodeBase*& node) +{ + if (auto decl = as<Decl>(node)) { - Decoder::WithKeyValuePair withPair(decoder); - - decodeEnum(value.kind, decoder); - switch (value.kind) + if (auto importedFromModule = _findModuleDeclWasImportedFrom(decl)) { - case ValNodeOperandKind::ConstantValue: - decode(value.values.intOperand, decoder); - break; - - case ValNodeOperandKind::ValNode: + if (decl == importedFromModule) { - Val* val = nullptr; - decode(val, decoder); - value.values.nodeOperand = val; + _writeImportedModule(importedFromModule); + return; } - break; - - case ValNodeOperandKind::ASTNode: + else { - Decl* decl = nullptr; - decode(decl, decoder); - value.values.nodeOperand = decl; + _writeImportedDecl(decl, importedFromModule); + return; } - break; } } - void decodeValue(TypeExp& value, Decoder& decoder) { decode(value.type, decoder); } + ASTSerializer serializer(this); - void decodeValue(QualType& value, Decoder& decoder) + if (auto val = as<Val>(node)) { - Decoder::WithObject withObject(decoder); - decode(value.type, decoder); - decode(value.isLeftValue, decoder); - decode(value.hasReadOnlyOnTarget, decoder); - decode(value.isWriteOnly, decoder); - } + val = val->resolve(); - void decodeValue(MatrixCoord& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.row, decoder); - decode(value.col, decoder); + // On the reading side of things, sublcasses of `Val` + // are deduplicated as part of creation, and will read the + // operands out immediately, so we mirror that approach + // on the writing side to make sure the code is consistent. + // + serialize(serializer, val->astNodeType); + serialize(serializer, val->m_operands); } - - void decodeValue(SPIRVAsmOperand::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); + serialize(serializer, node->astNodeType); + deferSerializeObjectContents(serializer, node); } +} - void decodeValue(SPIRVAsmOperand& value, Decoder& decoder) - { - Decoder::WithObject withObject(decoder); - decode(value.flavor, decoder); - decode(value.token, decoder); - decode(value.expr, decoder); - decode(value.bitwiseOrWith, decoder); - decode(value.knownValue, decoder); - decode(value.wrapInId, decoder); - decode(value.type, decoder); - } +void ASTDecodingContext::handleASTNode(NodeBase*& outNode) +{ + ASTSerializer serializer(this); - void decodeValue(SPIRVAsmInst& value, Decoder& decoder) + ASTNodeType typeTag; + serialize(serializer, typeTag); + switch (_getPseudoASTNodeType(typeTag)) { - Decoder::WithObject withObject(decoder); - decode(value.opcode, decoder); - decode(value.operands, decoder); - } + default: + break; + case PseudoASTNodeType::ImportedModule: + outNode = _readImportedModule(); + return; - template<typename T> - void decodeEnum(T& value, Decoder& decoder) - { - value = T(decoder.decode<Int32>()); + case PseudoASTNodeType::ImportedDecl: + outNode = _readImportedDecl(); + return; } - template<typename T> - void decodeSimpleValue(T& value, Decoder& decoder) + auto syntaxClass = SyntaxClass<NodeBase>(typeTag); + if (syntaxClass.isSubClassOf<Val>()) { - value = decoder.decode<T>(); - } + // Subclasses of `Val` are deduplicated as part + // of creation, so we need to read in their + // operands before we can create them, rather + // than allocating the object up front and + // then deserializing its content into it later. - void decodeValue(bool& value, Decoder& decoder) { value = decoder.decodeBool(); } - void decodeValue(Int32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(Int64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt32& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(UInt64& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(float& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } - void decodeValue(double& value, Decoder& decoder) { decodeSimpleValue(value, decoder); } + ValNodeDesc desc; + desc.type = syntaxClass; + serialize(serializer, desc.operands); - void decodeValue(uint8_t& value, Decoder& decoder) - { - value = uint8_t(decoder.decode<UInt32>()); - } + desc.init(); - void decodeValue(DeclVisibility& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BaseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(BuiltinRequirementKind& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(ASTNodeType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(ImageFormat& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TypeTag& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TryClauseType& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(CapabilityAtom& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(PreferRecomputeAttribute::SideEffectBehavior& value, Decoder& decoder) - { - decodeEnum(value, decoder); + auto node = _astBuilder->_getOrCreateImpl(std::move(desc)); + outNode = node; } - void decodeValue(LogicOperatorShortCircuitExpr::Flavor& value, Decoder& decoder) + else { - decodeEnum(value, decoder); - } - void decodeValue(TreatAsDifferentiableExpr::Flavor& value, Decoder& decoder) - { - decodeEnum(value, decoder); - } - void decodeValue(DeclAssociationKind& value, Decoder& decoder) { decodeEnum(value, decoder); } - void decodeValue(TokenType& value, Decoder& decoder) { decodeEnum(value, decoder); } + auto node = syntaxClass.createInstance(_astBuilder); + outNode = node; - - void decodeValue(SourceLoc& value, Decoder& decoder) - { - if (!decoder.decodeNull()) - { - SerialSourceLocData::SourceLoc intermediate; - decoder.decode(intermediate); - - if (_sourceLocReader) - { - auto sourceLoc = _sourceLocReader->getSourceLoc(intermediate); - value = sourceLoc; - } - } + deferSerializeObjectContents(serializer, node); } +} - template<typename T> - void decodeValue(T*& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - decodePtr(ptr, decoder, (T*)nullptr); - } +// +// AST{Encoding|Decoding}Context::handleASTNodeContents() +// - template<typename T> - void decodeValue(RefPtr<T>& ptr, Decoder& decoder) - { - if (decoder.decodeNull()) - ptr = nullptr; - else - { - // Hi Future Tess, - // - // The next step here is decoding logic for `WitnessTable`s. - // +void ASTEncodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); +} - decodePtr(*ptr.writeRef(), decoder, (T*)nullptr); - } - } +void ASTDecodingContext::handleASTNodeContents(NodeBase* node) +{ + ASTSerializer serializer(this); + serializeASTNodeContents(serializer, node); - void decodeValue(Modifiers& modifiers, Decoder& decoder) - { - Modifier** link = &modifiers.first; + _cleanUpASTNode(node); +} - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Modifier* modifier = nullptr; - decode(modifier, decoder); +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - *link = modifier; - link = &modifier->next; - } - } +void ASTEncodingContext::_writeImportedModule(ModuleDecl* moduleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedModule); + auto moduleName = moduleDecl->getName(); - template<typename T, int N> - void decodeValue(ShortList<T, N>& array, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } - } + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, moduleName); +} +ModuleDecl* ASTDecodingContext::_readImportedModule() +{ + ASTSerializer serializer(this); - template<typename T> - void decode(List<T>& array, Decoder& decoder) + Name* moduleName = nullptr; + serialize(serializer, moduleName); + auto module = _linkage->findOrImportModule(moduleName, _requestingSourceLoc, _sink); + if (!module) { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - T element; - decode(element, decoder); - array.add(element); - } + SLANG_ABORT_COMPILATION("failed to load an imported module during AST deserialization"); } + return module->getModuleDecl(); +} - template<typename T, size_t N> - void decode(T (&array)[N], Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - for (auto& element : array) - { - decode(element, decoder); - } - } +// +// AST{Encoding|Decoding}Context::_{write|read}ImportedModule() +// - template<typename K, typename V> - void decode(OrderedDictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); +void ASTEncodingContext::_writeImportedDecl(Decl* decl, ModuleDecl* importedFromModuleDecl) +{ + ASTNodeType type = _getAsASTNodeType(PseudoASTNodeType::ImportedDecl); + auto mangledName = getMangledName(getCurrentASTBuilder(), decl); - K key; - V value; - decode(key, decoder); - decode(value, decoder); + ASTSerializer serializer(this); + serialize(serializer, type); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); +} - dictionary.add(key, value); - } - } +NodeBase* ASTDecodingContext::_readImportedDecl() +{ + ASTSerializer serializer(this); - template<typename K, typename V> - void decode(Dictionary<K, V>& dictionary, Decoder& decoder) - { - Decoder::WithArray withArray(decoder); - while (decoder.hasElements()) - { - Decoder::WithKeyValuePair withPair(decoder); + ModuleDecl* importedFromModuleDecl = nullptr; + String mangledName; - K key; - V value; - decode(key, decoder); - decode(value, decoder); + serialize(serializer, importedFromModuleDecl); + serialize(serializer, mangledName); - dictionary.add(key, value); - } + auto importedFromModule = importedFromModuleDecl->module; + if (!importedFromModule) + { + return nullptr; } - template<typename T> - void decode(T& outValue, Decoder& decoder) + auto importedDecl = + importedFromModule->findExportFromMangledName(mangledName.getUnownedSlice()); + if (!importedDecl) { - decodeValue(outValue, decoder); + SLANG_ABORT_COMPILATION( + "failed to load an imported declaration during AST deserialization"); } + return importedDecl; +} -#if 0 // FIDDLE TEMPLATE: -%for _,T in ipairs(Slang.NodeBase.subclasses) do - void _decodeDataOf($T* obj, Decoder& decoder) - { -% if T.directSuperClass then - _decodeDataOf(static_cast<$(T.directSuperClass)*>(obj), decoder); -% end -% for _,f in ipairs(T.directFields) do - decode(obj->$f, decoder); -% end - } -%end -#else // FIDDLE OUTPUT: -#define FIDDLE_GENERATED_OUTPUT_ID 1 -#include "slang-serialize-ast.cpp.fiddle" -#endif // FIDDLE END -}; +// +// {write|read}SerializedModuleAST() +// + +void writeSerializedModuleAST( + RIFF::BuildCursor& cursor, + ModuleDecl* moduleDecl, + SerialSourceLocWriter* sourceLocWriter) +{ + // TODO: we might want to have a more careful pass here, + // where we only encode the public declarations. + + ASTEncodingContext context(cursor, moduleDecl, sourceLocWriter); + serialize(ASTSerializer(&context), moduleDecl); +} ModuleDecl* readSerializedModuleAST( Linkage* linkage, @@ -1546,9 +919,10 @@ ModuleDecl* readSerializedModuleAST( { ASTDecodingContext context(linkage, astBuilder, sink, chunk, sourceLocReader, requestingSourceLoc); - context.decodeAll(); - auto node = context.getDeclByID(0); - auto moduleDecl = as<ModuleDecl>(node); + + ModuleDecl* moduleDecl = nullptr; + serialize(ASTSerializer(&context), moduleDecl); return moduleDecl; } + } // namespace Slang diff --git a/source/slang/slang-serialize-ast.h b/source/slang/slang-serialize-ast.h index 45c799e9c..86ba6e772 100644 --- a/source/slang/slang-serialize-ast.h +++ b/source/slang/slang-serialize-ast.h @@ -11,8 +11,10 @@ namespace Slang { +class Linkage; + void writeSerializedModuleAST( - Encoder* encoder, + RIFF::BuildCursor& cursor, ModuleDecl* moduleDecl, SerialSourceLocWriter* sourceLocWriter); diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index 665775dc5..ba8c74ca4 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -26,7 +26,7 @@ private: RefPtr<SerialSourceLocWriter> _sourceLocWriter; RIFF::Builder _riff; - Encoder _encoder; + RIFF::BuildCursor _cursor; public: ModuleEncodingContext(SerialContainerUtil::WriteOptions const& options, Stream* stream) @@ -37,12 +37,12 @@ public: _sourceLocWriter = new SerialSourceLocWriter(options.sourceManager); } - _encoder = Encoder(_riff); + _cursor = RIFF::BuildCursor(_riff); } ~ModuleEncodingContext() { - _encoder = Encoder(_riff.getRootChunk()); + _cursor = RIFF::BuildCursor(_riff.getRootChunk()); encodeFinalPieces(); _riff.writeTo(_stream); } @@ -53,7 +53,7 @@ public: // is simply a matter of encoding the module for each // of the translation units that got compiled. // - Encoder::WithKeyValuePair withArray(&_encoder, SerialBinary::kModuleListFourCc); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kModuleListFourCc); for (TranslationUnitRequest* translationUnit : frontEndReq->translationUnits) { SLANG_RETURN_ON_FAIL(encode(translationUnit->module)); @@ -63,14 +63,14 @@ public: SlangResult encode(FrontEndCompileRequest* frontEndReq) { - Encoder::WithObject withObject(&_encoder, SerialBinary::kContainerFourCc); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kContainerFourCc); SLANG_RETURN_ON_FAIL(encodeModuleList(frontEndReq)); return SLANG_OK; } SlangResult encode(EndToEndCompileRequest* request) { - Encoder::WithObject withObject(&_encoder, SerialBinary::kContainerFourCc); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kContainerFourCc); // Encoding an end-to-end compile request starts with the same // work as for a front-end request: we encode each of @@ -99,7 +99,7 @@ public: auto sink = request->getSink(); auto program = request->getSpecializedGlobalAndEntryPointsComponentType(); { - Encoder::WithArray withArray(&_encoder); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kArrayFourCC); for (auto target : linkage->targets) { @@ -112,7 +112,7 @@ public: // and we need to encode information about each of them. // { - Encoder::WithArray withArray(&_encoder, SerialBinary::kEntryPointListFourCc); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kEntryPointListFourCc); auto entryPointCount = program->getEntryPointCount(); for (Index ii = 0; ii < entryPointCount; ++ii) @@ -142,34 +142,41 @@ public: SLANG_RETURN_ON_FAIL( writer.write(irModule, _sourceLocWriter, _options.optionFlags, &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _encoder)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _cursor)); return SLANG_OK; } - void encode(Name* name) { _encoder.encode(name->text); } + void encodeData(void const* data, size_t size, FourCC type) + { + _cursor.addDataChunk(type, data, size); + } - void encode(String const& value) { _encoder.encode(value); } + void encode(String const& value, FourCC type = SerialBinary::kStringFourCC) + { + encodeData(value.getBuffer(), value.getLength(), type); + } - void encode(uint32_t value) { _encoder.encode(UInt(value)); } + void encode(Name* name, FourCC type = SerialBinary::kNameFourCC) { encode(name->text, type); } - void encodeData(void const* data, size_t size) { _encoder.encodeData(data, size); } + + void encode(uint32_t value, FourCC type = SerialBinary::kUInt32FourCC) + { + encodeData(&value, sizeof(value), type); + } SlangResult encode(EntryPoint* entryPoint, String const& entryPointMangledName) { - Encoder::WithObject withObject(&_encoder, SerialBinary::kEntryPointFourCc); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kEntryPointFourCc); { - Encoder::WithObject withProperty(&_encoder, SerialBinary::kNameFourCC); - encode(entryPoint->getName()); + encode(entryPoint->getName(), SerialBinary::kNameFourCC); } { - Encoder::WithObject withProperty(&_encoder, SerialBinary::kProfileFourCC); - encode(entryPoint->getProfile().raw); + encode(entryPoint->getProfile().raw, SerialBinary::kProfileFourCC); } { - Encoder::WithObject withProperty(&_encoder, SerialBinary::kMangledNameFourCC); - encode(entryPointMangledName); + encode(entryPointMangledName, SerialBinary::kMangledNameFourCC); } return SLANG_OK; @@ -181,7 +188,7 @@ public: if (!(_options.optionFlags & (SerialOptionFlag::IRModule | SerialOptionFlag::ASTModule))) return SLANG_OK; - Encoder::WithObject withModule(&_encoder, SerialBinary::kModuleFourCC); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, SerialBinary::kModuleFourCC); // The first piece that we write for a module is its header. // The header is intended to provide information that can be @@ -194,15 +201,14 @@ public: // sense to serialize it separately from all the rest. // { - Encoder::WithObject withProperty(&_encoder, SerialBinary::kNameFourCC); - _encoder.encodeString(module->getNameObj()->text); + encode(module->getNameObj(), SerialBinary::kNameFourCC); } // The header includes a digest of all the compile options and // the files that the compiled result depended on. // auto digest = module->computeDigest(); - _encoder.encodeData(PropertyKeys<Module>::Digest, digest.data, sizeof(digest.data)); + _cursor.addDataChunk(PropertyKeys<Module>::Digest, digest.data, sizeof(digest.data)); // The header includes an array of the paths of all of the // files that the compiled result depended on. @@ -221,7 +227,7 @@ public: IRSerialWriter writer; SLANG_RETURN_ON_FAIL( writer.write(irModule, _sourceLocWriter, _options.optionFlags, &serialData)); - SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _encoder)); + SLANG_RETURN_ON_FAIL(IRSerialWriter::writeTo(serialData, _cursor)); } } @@ -232,9 +238,8 @@ public: { if (auto moduleDecl = module->getModuleDecl()) { - Encoder::WithKeyValuePair withKey(&_encoder, PropertyKeys<Module>::ASTModule); - - writeSerializedModuleAST(&_encoder, moduleDecl, _sourceLocWriter); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, PropertyKeys<Module>::ASTModule); + writeSerializedModuleAST(_cursor, moduleDecl, _sourceLocWriter); } } @@ -243,7 +248,7 @@ public: SlangResult encodeModuleDependencyPaths(Module* module) { - Encoder::WithObject withProperty(&_encoder, PropertyKeys<Module>::FileDependencies); + SLANG_SCOPED_RIFF_BUILDER_LIST_CHUNK(_cursor, PropertyKeys<Module>::FileDependencies); // TODO(tfoley): This is some of the most complicated logic // in the encoding system, because it tries to translate @@ -317,7 +322,6 @@ public: } Path::getCanonical(linkageRoot, linkageRoot); - Encoder::WithArray withArray(&_encoder); for (auto file : fileDependencies) { if (file->getPathInfo().hasFoundPath()) @@ -334,26 +338,26 @@ public: auto relativeModulePath = Path::getRelativePath(linkageRoot, canonicalModulePath); - _encoder.encodeString(relativeModulePath); + encode(relativeModulePath); } else { // For all other dependnet files, store them as relative paths with respect // to the module's path. canonicalFilePath = Path::getRelativePath(moduleDir, canonicalFilePath); - _encoder.encodeString(canonicalFilePath); + encode(canonicalFilePath); } } else { // If the module is coming from string instead of an actual file, store it as // is. - _encoder.encodeString(canonicalModulePath); + encode(canonicalModulePath); } } else { - _encoder.encodeString(file->getPathInfo().getMostUniqueIdentity()); + encode(file->getPathInfo().getMostUniqueIdentity()); } } @@ -369,18 +373,19 @@ public: SerialSourceLocData debugData; _sourceLocWriter->write(&debugData); - debugData.writeTo(_encoder); + debugData.writeTo(_cursor); } // Write the container string table if (_containerStringPool.getAdded().getCount() > 0) { - Encoder::WithKeyValuePair withKey(&_encoder, SerialBinary::kStringTableFourCc); - List<char> encodedTable; SerialStringTableUtil::encodeStringTable(_containerStringPool, encodedTable); - _encoder.encodeData(encodedTable.getBuffer(), encodedTable.getCount()); + _cursor.addDataChunk( + SerialBinary::kStringTableFourCc, + encodedTable.getBuffer(), + encodedTable.getCount()); } return SLANG_OK; @@ -425,14 +430,15 @@ public: String StringChunk::getValue() const { - return Decoder(this).decodeString(); + return String(UnownedStringSlice((char const*)getPayload(), getPayloadSize())); } RIFF::ChunkList<StringChunk> ModuleChunk::getFileDependencies() const { - Decoder decoder(this); - Decoder::WithProperty withProperty(decoder, PropertyKeys<Module>::FileDependencies); - return as<RIFF::ListChunk>(decoder.getCurrentChunk())->getChildren().cast<StringChunk>(); + auto found = findListChunk(PropertyKeys<Module>::FileDependencies); + if (!found) + return RIFF::ChunkList<StringChunk>(); + return found->getChildren().cast<StringChunk>(); } ModuleChunk const* ModuleChunk::find(RIFF::ListChunk const* baseChunk) @@ -453,14 +459,12 @@ SHA1::Digest ModuleChunk::getDigest() const String ModuleChunk::getName() const { - // TODO(tfoley): This kind of logic needs a way - // to be greatly simplified, so that we don't - // have to express such complicated logic for - // simply extracting a single string property... - // - Decoder decoder(this); - Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); - return decoder.decodeString(); + auto found = findDataChunk(SerialBinary::kNameFourCC); + if (!found) + { + SLANG_UNEXPECTED("module chunk had no name"); + } + return static_cast<StringChunk const*>(found)->getValue(); } @@ -506,41 +510,32 @@ RIFF::ChunkList<EntryPointChunk> ContainerChunk::getEntryPoints() const String EntryPointChunk::getMangledName() const { - // TODO(tfoley): This kind of logic needs a way - // to be greatly simplified, so that we don't - // have to express such complicated logic for - // simply extracting a single string property... - // - Decoder decoder(this); - Decoder::WithProperty withProperty(decoder, SerialBinary::kMangledNameFourCC); - return decoder.decodeString(); + auto found = findDataChunk(SerialBinary::kMangledNameFourCC); + if (!found) + { + SLANG_UNEXPECTED("entry point chunk had no mangled name"); + } + return static_cast<StringChunk const*>(found)->getValue(); } String EntryPointChunk::getName() const { - // TODO(tfoley): This kind of logic needs a way - // to be greatly simplified, so that we don't - // have to express such complicated logic for - // simply extracting a single string property... - // - Decoder decoder(this); - Decoder::WithProperty withProperty(decoder, SerialBinary::kNameFourCC); - return decoder.decodeString(); + auto found = findDataChunk(SerialBinary::kNameFourCC); + if (!found) + { + SLANG_UNEXPECTED("entry point chunk had no name"); + } + return static_cast<StringChunk const*>(found)->getValue(); } Profile EntryPointChunk::getProfile() const { - // TODO(tfoley): This kind of logic needs a way - // to be greatly simplified, so that we don't - // have to express such complicated logic for - // simply extracting a single string property... - // - Decoder decoder(this); - Decoder::WithProperty withProperty(decoder, SerialBinary::kProfileFourCC); - - Profile::RawVal rawVal; - decoder.decode(rawVal); - + auto found = findDataChunk(SerialBinary::kProfileFourCC); + if (!found) + { + SLANG_UNEXPECTED("entry point chunk had no profile"); + } + auto rawVal = found->readPayloadAs<Profile::RawVal>(); return Profile(rawVal); } diff --git a/source/slang/slang-serialize-riff.cpp b/source/slang/slang-serialize-riff.cpp new file mode 100644 index 000000000..01b39e825 --- /dev/null +++ b/source/slang/slang-serialize-riff.cpp @@ -0,0 +1,897 @@ +// slang-serialize-riff.cpp +#include "slang-serialize-riff.h" + +namespace Slang +{ + +// +// RIFFSerialWriter +// + +RIFFSerialWriter::RIFFSerialWriter(RIFF::ChunkBuilder* chunk, FourCC type) + : _cursor(chunk) +{ + _initialize(type); +} + +RIFFSerialWriter::RIFFSerialWriter(RIFF::Builder& riff, FourCC type) + : _cursor(riff) +{ + _initialize(type); +} + +RIFFSerialWriter::~RIFFSerialWriter() +{ + // We need to flush any pending operations to + // write objects into the object definition list chunk. + // + _flush(); +} + +SerializationMode RIFFSerialWriter::getMode() +{ + return SerializationMode::Write; +} + +void RIFFSerialWriter::handleBool(bool& value) +{ + _cursor.addDataChunk(value ? RIFFSerial::kTrueFourCC : RIFFSerial::kFalseFourCC, nullptr, 0); +} + +void RIFFSerialWriter::handleInt8(int8_t& value) +{ + _writeInt(value); +} + +void RIFFSerialWriter::handleInt16(int16_t& value) +{ + _writeInt(value); +} + +void RIFFSerialWriter::handleInt32(Int32& value) +{ + _writeInt(value); +} + +void RIFFSerialWriter::handleInt64(Int64& value) +{ + _writeInt(value); +} + +void RIFFSerialWriter::handleUInt8(uint8_t& value) +{ + _writeUInt(value); +} + +void RIFFSerialWriter::handleUInt16(uint16_t& value) +{ + _writeUInt(value); +} + +void RIFFSerialWriter::handleUInt32(UInt32& value) +{ + _writeUInt(value); +} + +void RIFFSerialWriter::handleUInt64(UInt64& value) +{ + _writeUInt(value); +} + +void RIFFSerialWriter::handleFloat32(float& value) +{ + _writeFloat(value); +} + +void RIFFSerialWriter::handleFloat64(double& value) +{ + _writeFloat(value); +} + +void RIFFSerialWriter::handleString(String& value) +{ + _cursor.addDataChunk(RIFFSerial::kStringFourCC, value.getBuffer(), value.getLength()); +} + +void RIFFSerialWriter::_writeInt(Int64 value) +{ + // We pick a 32-bit representation if it can + // faithfully represent the value, and a 64-bit + // representation otherwise. + // + if (Int32(value) == value) + { + auto v = Int32(value); + _cursor.addDataChunk(RIFFSerial::kInt32FourCC, &v, sizeof(v)); + } + else + { + _cursor.addDataChunk(RIFFSerial::kInt64FourCC, &value, sizeof(value)); + } +} + +void RIFFSerialWriter::_writeUInt(UInt64 value) +{ + // We pick a 32-bit representation if it can + // faithfully represent the value, and a 64-bit + // representation otherwise. + // + if (UInt32(value) == value) + { + auto v = UInt32(value); + _cursor.addDataChunk(RIFFSerial::kUInt32FourCC, &v, sizeof(v)); + } + else + { + _cursor.addDataChunk(RIFFSerial::kUInt64FourCC, &value, sizeof(value)); + } +} + +void RIFFSerialWriter::_writeFloat(double value) +{ + // We pick a 32-bit representation if it can + // faithfully represent the value, and a 64-bit + // representation otherwise. + // + if (float(value) == value) + { + auto v = float(value); + _cursor.addDataChunk(RIFFSerial::kFloat32FourCC, &v, sizeof(v)); + } + else + { + _cursor.addDataChunk(RIFFSerial::kFloat64FourCC, &value, sizeof(value)); + } +} + +void RIFFSerialWriter::beginArray() +{ + _cursor.beginListChunk(RIFFSerial::kArrayFourCC); +} + +void RIFFSerialWriter::endArray() +{ + _cursor.endChunk(); +} + +void RIFFSerialWriter::beginDictionary() +{ + _cursor.beginListChunk(RIFFSerial::kDictionaryFourCC); +} + +void RIFFSerialWriter::endDictionary() +{ + _cursor.endChunk(); +} + +bool RIFFSerialWriter::hasElements() +{ + return false; +} + +void RIFFSerialWriter::beginStruct() +{ + _cursor.beginListChunk(RIFFSerial::kStructFourCC); +} + +void RIFFSerialWriter::handleFieldKey(char const* name, Int index) +{ + // For now we are ignoring field keys, and treating + // structs as basically equivalent to tuples. + SLANG_UNUSED(name); + SLANG_UNUSED(index); +} + +void RIFFSerialWriter::endStruct() +{ + _cursor.endChunk(); +} + +void RIFFSerialWriter::beginTuple() +{ + _cursor.beginListChunk(RIFFSerial::kTupleFourCC); +} + +void RIFFSerialWriter::endTuple() +{ + _cursor.endChunk(); +} + +void RIFFSerialWriter::beginOptional() +{ + _cursor.beginListChunk(RIFFSerial::kOptionalFourCC); +} + +void RIFFSerialWriter::endOptional() +{ + _cursor.endChunk(); +} + +void RIFFSerialWriter::handleSharedPtr(void*& value, Callback callback, void* userData) +{ + // Because we are writing, we only care about the + // pointer that is already present in `value`. + // + void* ptr = value; + + // The first special case we check for is a null pointer, + // which we can serialize as an inline value. + // + if (ptr == nullptr) + { + _cursor.addDataChunk(RIFFSerial::kNullFourCC, nullptr, 0); + return; + } + + // Next, we check to see if we have encountered this + // pointer before, in which case we've already allocated + // an index for it in the object definition list, and + // we can simply write a reference to that index. + // + if (auto found = _mapPtrToObjectIndex.tryGetValue(ptr)) + { + auto objectIndex = *found; + _writeObjectReference(objectIndex); + return; + } + + // If we have a non-null pointer that we haven't seen + // before, then we will allocate a new entry in the + // object definition list, and the pointer itself + // will be written as a reference to that entry. + // + auto objectIndex = ObjectIndex(_objects.getCount()); + _mapPtrToObjectIndex.add(ptr, objectIndex); + _writeObjectReference(objectIndex); + + // At this point we've correctly written the *reference* + // to the object (and will be able to write further + // references later if we see an identical pointer), + // but we also need to make sure that the *definition* + // of the object gets written into the object definition + // list chunk. + // + // The `callback` that was passed in can be used to + // write out the members of the object, but if we + // simply invoked it here and now we would be at risk + // of introducing unbounded recursion in cases where + // the object graph contains very long pointer chains. + // + // (Note that we are not at risk of *infinite* recursion, + // because we have already cached the index for the + // object into `_mapPtrToObjectIndex`) + // + // We will simply add an entry to our `_objects` array + // to represent the to-be-written object, and store + // the pointer and callback there so that we can write + // everything out later, in `_flush()`. + // + ObjectInfo objectInfo; + objectInfo.ptr = ptr; + objectInfo.callback = callback; + objectInfo.userData = userData; + _objects.add(objectInfo); +} + +void RIFFSerialWriter::handleUniquePtr(void*& value, Callback callback, void* userData) +{ + // We treat all pointers as shared pointers, because there isn't really + // an optimized representation we would want to use for the unique case. + // + handleSharedPtr(value, callback, userData); +} + +void RIFFSerialWriter::handleDeferredObjectContents( + void* valuePtr, + Callback callback, + void* userData) +{ + // Because we are already deferring writing of the *entirety* of + // an object's members as part of how `handleSharedPtr()` works, + // we don't need to implement deferral at this juncture. + // + // (In practice the `handleDeferredObjectContents()` operation is + // more for the benefit of reading than writing). + // + callback(valuePtr, userData); +} + +void RIFFSerialWriter::_writeObjectReference(ObjectIndex index) +{ + _cursor.addDataChunk(RIFFSerial::kObjectReferenceFourCC, &index, sizeof(index)); +} + +void RIFFSerialWriter::_initialize(FourCC type) +{ + // The entire content that we write will be nested + // in a single list chunk, with the type that was + // passed in. + // + _cursor.beginListChunk(type); + + // The first child chunk needs to be the object + // definition list chunk, so we create it up front. + // + _objectDefinitionListChunk = _cursor.addListChunk(RIFFSerial::kObjectDefinitionListFourCC); +} + +void RIFFSerialWriter::_flush() +{ + // At this point we might have zero or more object + // waiting to be written into the object definition list + // chunk, and we need to make sure that they all + // get a chance to write their content out. + // + _cursor.setCurrentChunk(_objectDefinitionListChunk); + + // Note that we do *not* compute `_objects.getCount()` outside + // of the loop here, because writing out one object definition + // could cause other objects to be referenced, which could + // in turn add more entries to `_objects` that need to be + // written out. + // + while (_writtenObjectDefinitionCount < _objects.getCount()) + { + auto objectIndex = _writtenObjectDefinitionCount++; + auto objectInfo = _objects[objectIndex]; + + // We shouldn't ever be putting a null pointer into the + // object definition list; there is logic in `handleSharedPtr()` + // that explicitly checks for a null pointer and does an + // early-exit in that case. + // + SLANG_ASSERT(objectInfo.ptr); + + // The callback that was passed into `handleSharedPtr()` should + // be able to write out the value of the pointed-to object. + // + // Note that we are passing the *address* of `objectInfo.ptr` + // and not just its *value*, because this callback is used + // for both reading and writing, and in the reading case it + // needs to be invoked on a pointer-pointer (e.g., a `T**` when + // serializing an object pointer `T*`) so that the callee + // can set the pointed-to pointer to whatever object it + // allocates or finds. + // + objectInfo.callback(&objectInfo.ptr, objectInfo.userData); + + // TODO(tfoley): There is an important invariant here that + // the callback had better only write *one* value, but + // that is not currently being enforced. + } +} + +// +// RIFFSerialReader +// + +RIFFSerialReader::RIFFSerialReader(RIFF::Chunk const* chunk, FourCC type) + : _cursor(chunk) +{ + _initialize(type); +} + +RIFFSerialReader::~RIFFSerialReader() +{ + _flush(); +} + +SerializationMode RIFFSerialReader::getMode() +{ + return SerializationMode::Read; +} + +void RIFFSerialReader::handleBool(bool& value) +{ + switch (_peekChunkType()) + { + case RIFFSerial::kTrueFourCC: + _advanceCursor(); + value = true; + break; + + case RIFFSerial::kFalseFourCC: + _advanceCursor(); + value = false; + break; + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + break; + } +} + +void RIFFSerialReader::handleInt8(int8_t& value) +{ + value = int8_t(_readInt()); +} + +void RIFFSerialReader::handleInt16(int16_t& value) +{ + value = int16_t(_readInt()); +} + +void RIFFSerialReader::handleInt32(Int32& value) +{ + value = Int32(_readInt()); +} + +void RIFFSerialReader::handleInt64(Int64& value) +{ + value = Int64(_readInt()); +} + +void RIFFSerialReader::handleUInt8(uint8_t& value) +{ + value = uint8_t(_readUInt()); +} + +void RIFFSerialReader::handleUInt16(uint16_t& value) +{ + value = uint16_t(_readUInt()); +} + +void RIFFSerialReader::handleUInt32(UInt32& value) +{ + value = UInt32(_readUInt()); +} + +void RIFFSerialReader::handleUInt64(UInt64& value) +{ + value = UInt64(_readUInt()); +} + +void RIFFSerialReader::handleFloat32(float& value) +{ + value = float(_readFloat()); +} + +void RIFFSerialReader::handleFloat64(double& value) +{ + value = double(_readFloat()); +} + +void RIFFSerialReader::handleString(String& value) +{ + if (_peekChunkType() != RIFFSerial::kStringFourCC) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + return; + } + + auto dataChunk = as<RIFF::DataChunk>(_cursor); + if (!dataChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + return; + } + + auto size = dataChunk->getPayloadSize(); + + value = String(); + value.appendRepeatedChar(' ', size); + dataChunk->writePayloadInto((char*)value.getBuffer(), size); + + _advanceCursor(); +} + +void RIFFSerialReader::beginArray() +{ + _beginListChunk(RIFFSerial::kArrayFourCC); +} + +void RIFFSerialReader::endArray() +{ + _endListChunk(); +} + + +void RIFFSerialReader::beginDictionary() +{ + _beginListChunk(RIFFSerial::kDictionaryFourCC); +} + +void RIFFSerialReader::endDictionary() +{ + _endListChunk(); +} + +bool RIFFSerialReader::hasElements() +{ + return _cursor.get() != nullptr; +} + +void RIFFSerialReader::beginStruct() +{ + _beginListChunk(RIFFSerial::kStructFourCC); +} + +void RIFFSerialReader::handleFieldKey(char const* name, Int index) +{ + // For now we are ignoring field keys, and treating + // structs as basically equivalent to tuples. + SLANG_UNUSED(name); + SLANG_UNUSED(index); +} + +void RIFFSerialReader::endStruct() +{ + _endListChunk(); +} + +void RIFFSerialReader::beginTuple() +{ + _beginListChunk(RIFFSerial::kTupleFourCC); +} + +void RIFFSerialReader::endTuple() +{ + _endListChunk(); +} + +void RIFFSerialReader::beginOptional() +{ + _beginListChunk(RIFFSerial::kOptionalFourCC); +} + +void RIFFSerialReader::endOptional() +{ + _endListChunk(); +} + +RIFFSerialReader::ObjectIndex RIFFSerialReader::_readObjectReference() +{ + if (_peekChunkType() != RIFFSerial::kObjectReferenceFourCC) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(false); + } + + auto objectIndex = _readDataChunk<ObjectIndex>(); + SLANG_ASSERT(objectIndex >= 0 && objectIndex < _objects.getCount()); + return objectIndex; +} + +void RIFFSerialReader::handleSharedPtr(void*& value, Callback callback, void* userData) +{ + // The logic here largely mirrors what appears in + // `RIFFSerialWriter::handleSharedPtr`. + // + // We first check for an explicitly written null pointer. + // If we find one our work is very easy. + // + if (_peekChunkType() == RIFFSerial::kNullFourCC) + { + _advanceCursor(); + value = nullptr; + return; + } + + // Otherwise, we expect to find a reference to + // an object index. + // + // Note that `_readObjectReference()` already asserts + // that the index is in-bounds, so we don't repeat + // that test here. + // + auto objectIndex = _readObjectReference(); + + // Now we need to check if we've previously read in + // a reference to the same object. + // + auto& objectInfo = _objects[objectIndex]; + if (objectInfo.state != ObjectState::Unread) + { + // We've seen this object before, although it + // is still possible that we are in the middle + // of reading it as part of an invocation + // of `handleSharedPtr()` further up the call + // stack. + // + // If a non-nullpointer value has already been + // written into the `objectInfo`, then that means + // the callback that was run for the prior (or + // in-flight) read operation has already allocated + // or found an object and written it out. + // In that case we will trust the value. + // + if (objectInfo.ptr == nullptr) + { + // It is possible that the pointer is null because + // the callback that was invoked explicitly *chose* + // to yield a null pointer (perhaps the application + // is choosing not to deserialize some optional + // piece of state). + // + // However, if there is still a callback in-flight + // to read this object, and the pointer is null, + // then we have reached a circular reference, + // and need to signal an error. + // + if (objectInfo.state == ObjectState::ReadingInProgress) + { + SLANG_UNEXPECTED("circularity detected in RIFF deserialization"); + } + } + value = objectInfo.ptr; + return; + } + + // At this point we are reading a reference to an + // object index that has not yet been read at all. + // + SLANG_ASSERT(objectInfo.state == ObjectState::Unread); + + // We cannot return from this function until we have + // stored a pointer into `value`, to represent the + // deserialized object. + // + // Thus we will set ourselves up to start reading + // from the relevant object definition, and invoke + // the callback that was passed in. + // + // Calling into user-defined serialization logic from + // within this function creates the possibility of + // unbounded/infinite recursion, so it is vital that + // the user is properly using `deferSerializeObjectContents()` + // to delay reading data that isn't immediately + // necessary. + // + // We will still set the `objectInfo.state` to reflect + // this in-flight operation so that we can detect + // a cirularity if one occurs at runtime. + // + objectInfo.state = ObjectState::ReadingInProgress; + + // We save/restore the current cursor around + // the callback, because we need to be able + // to return to the current state to continue + // reading whatever comes after the pointer + // we were invoked to read. + // + _pushCursor(); + _cursor = objectInfo.definitionChunk; + + // Note that we are passing the address of `objectInfo.ptr`, + // and `objectInfo` is a reference to an element of the + // `_objects` array. Thus whenever the `callback` stores + // a pointer into that output parameter, the value it writes + // will automatically be visible to any subsequent calls + // to `handleSharedPtr()`, even if they occur before + // `callback` returns. + // + // Thus a "true" circularity can only occur if the callback + // recursively reads a reference to the same object again + // *before* it allocates the in-memory representation of + // that objects and stores a pointer to it into the output + // parameter. + // + callback(&objectInfo.ptr, userData); + + _popCursor(); + + objectInfo.state = ObjectState::ReadingComplete; + + value = objectInfo.ptr; +} + +void RIFFSerialReader::handleUniquePtr(void*& value, Callback callback, void* userData) +{ + // We treat all pointers as shared pointers, because there isn't really + // an optimized representation we would want to use for the unique case. + // + handleSharedPtr(value, callback, userData); +} + +void RIFFSerialReader::handleDeferredObjectContents( + void* valuePtr, + Callback callback, + void* userData) +{ + // Unlike the case in `RIFFSerialWriter::handleDeferredObjectContents()`, + // we very much *do* want to delay invoking the callback until later. + // + // There is a kind of symmetry going on, where the writer delays the + // callback passed to `handleSharedPtr()`, but *not* the callback + // passed to `handleDeferredObjectContents()`, while the reader + // does the opposite: immediately calls the callback in `handleSharedPtr()` + // but delays calling it here. + + // We make sure to save the current `_cursor` value along with + // the arguments that will be passed into the callback, so that + // we can restore the reader to this state before invoking + // the callbak in `_flush()`. + + DeferredAction deferredAction; + deferredAction.savedCursor = _cursor; + deferredAction.valuePtr = valuePtr; + deferredAction.callback = callback; + deferredAction.userData = userData; + + _deferredActions.add(deferredAction); +} + +void RIFFSerialReader::_initialize(FourCC type) +{ + // All of the content will have been serialized as a single RIFF + // list chunk (possibly a root chunk if this content comprises + // an entire file), with the given `type`. + // + _beginListChunk(type); + + // The first child chunk should be the object definition list + // chunk, and we will proactively read through all of the + // entries in that chunk to build up the `_objects` array. + // + // This operation takes linear time in the number of serialized + // objects, independent of their size, because the RIFF chunk + // headers allow us to skip over the content of each of these + // object-definition chunks. + // + _beginListChunk(RIFFSerial::kObjectDefinitionListFourCC); + while (auto objectDefinitionChunk = _cursor.get()) + { + ObjectInfo objectInfo; + objectInfo.definitionChunk = objectDefinitionChunk; + _objects.add(objectInfo); + + _advanceCursor(); + } + _endListChunk(); +} + +void RIFFSerialReader::_flush() +{ + // We need to flush any actions that were deferred + // and are still pending. + // + while (_deferredActions.getCount() != 0) + { + // TODO: For simplicity we are using the `_deferredActions` + // array as a stack (LIFO), but it would be good to + // check whether there is a menaingful difference in how + // large the array would need to grow for a FIFO vs. LIFO, + // and pick the better option. + // + auto deferredAction = _deferredActions.getLast(); + _deferredActions.removeLast(); + + _cursor = deferredAction.savedCursor; + deferredAction.callback(deferredAction.valuePtr, deferredAction.userData); + } +} + +FourCC RIFFSerialReader::_peekChunkType() +{ + auto chunk = _cursor.get(); + if (!chunk) + return 0; + return chunk->getType(); +} + +Int64 RIFFSerialReader::_readInt() +{ + switch (_peekChunkType()) + { + case RIFFSerial::kInt64FourCC: + return _readDataChunk<Int64>(); + case RIFFSerial::kInt32FourCC: + return _readDataChunk<Int32>(); + + case RIFFSerial::kUInt32FourCC: + return _readDataChunk<UInt32>(); + + case RIFFSerial::kUInt64FourCC: + { + auto uintValue = _readDataChunk<UInt64>(); + if (Int64(uintValue) < 0) + { + SLANG_UNEXPECTED("signed/unsigned mismatch in RIFF"); + } + return Int64(uintValue); + } + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(0); + } +} + +UInt64 RIFFSerialReader::_readUInt() +{ + switch (_peekChunkType()) + { + case RIFFSerial::kUInt64FourCC: + return _readDataChunk<UInt64>(); + case RIFFSerial::kUInt32FourCC: + return _readDataChunk<UInt32>(); + + case RIFFSerial::kInt32FourCC: + case RIFFSerial::kInt64FourCC: + { + auto intValue = _readInt(); + if (intValue < 0) + { + SLANG_UNEXPECTED("signed/unsigned mismatch in RIFF"); + } + return UInt64(intValue); + } + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(0); + } +} + +double RIFFSerialReader::_readFloat() +{ + switch (_peekChunkType()) + { + case RIFFSerial::kFloat32FourCC: + return _readDataChunk<float>(); + case RIFFSerial::kFloat64FourCC: + return _readDataChunk<double>(); + + default: + SLANG_UNEXPECTED("invalid format in RIFF"); + UNREACHABLE_RETURN(0); + } +} + +void RIFFSerialReader::_readDataChunk(void* outData, size_t dataSize) +{ + auto dataChunk = as<RIFF::DataChunk>(_cursor); + if (!dataChunk) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + return; + } + auto size = dataChunk->getPayloadSize(); + if (size < dataSize) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + return; + } + dataChunk->writePayloadInto(outData, dataSize); + _advanceCursor(); +} + + +void RIFFSerialReader::_beginListChunk(FourCC type) +{ + auto listChunk = as<RIFF::ListChunk>(_cursor); + if (!listChunk || listChunk->getType() != type) + { + SLANG_UNEXPECTED("invalid format in RIFF"); + } + + _advanceCursor(); + _pushCursor(); + + _cursor = listChunk->getFirstChild(); +} + +void RIFFSerialReader::_endListChunk() +{ + _popCursor(); +} + +void RIFFSerialReader::_advanceCursor() +{ + _cursor = _cursor.getNextSibling(); +} + +void RIFFSerialReader::_pushCursor() +{ + _stack.add(_cursor); +} + +void RIFFSerialReader::_popCursor() +{ + SLANG_ASSERT(_stack.getCount() != 0); + _cursor = _stack.getLast(); + _stack.removeLast(); +} + + +} // namespace Slang diff --git a/source/slang/slang-serialize-riff.h b/source/slang/slang-serialize-riff.h new file mode 100644 index 000000000..a464a5ded --- /dev/null +++ b/source/slang/slang-serialize-riff.h @@ -0,0 +1,431 @@ +// slang-serialize-riff.h +#ifndef SLANG_SERIALIZE_RIFF_H +#define SLANG_SERIALIZE_RIFF_H + +// +// This file provides implementations of `ISerializerImpl` that +// serialize hierarchical data in a RIFF-based format. +// +// This implementation can be seen as an adapter between the +// `Slang::Serializer` and `Slang::RIFF` subsystems, and also +// serves an an example of how to write a complete reader/writer +// pair for a new serialization format. +// + +#include "../core/slang-riff.h" +#include "slang-serialize.h" + +namespace Slang +{ + +namespace RIFFSerial +{ +// +// Each value in the hierarchy will be ended as a RIFF chunk. +// The type of the chunk, and whether it is a list or data +// chunk will depend on the kind of value. +// +// All of the serialized data will be encapsulated in a single +// list chunk. This chunk can have its `FourCC` customized, +// but a default is also provided. +// + +/// Default type for root chunk of a serialized object graph. +static const FourCC::RawValue kRootFourCC = SLANG_FOUR_CC('r', 'o', 'o', 't'); + +// +// Simple numeric values are stored as data chunks. +// Rather than go down to the granularity of 16- and +// 8-bit integers, we stick to 32- and 64-bit values +// only, since the overhead of a RIFF chunk header +// is already 64 bits (so the savings would be +// minimal). +// + +static const FourCC::RawValue kInt32FourCC = SLANG_FOUR_CC('i', '3', '2', ' '); +static const FourCC::RawValue kInt64FourCC = SLANG_FOUR_CC('i', '6', '4', ' '); + +static const FourCC::RawValue kUInt32FourCC = SLANG_FOUR_CC('u', '3', '2', ' '); +static const FourCC::RawValue kUInt64FourCC = SLANG_FOUR_CC('u', '6', '4', ' '); + +static const FourCC::RawValue kFloat32FourCC = SLANG_FOUR_CC('f', '3', '2', ' '); +static const FourCC::RawValue kFloat64FourCC = SLANG_FOUR_CC('f', '6', '4', ' '); + +// +// Boolean values are stored as empty chunks, with a unique +// type tag for each of the two possible values. +// + +static const FourCC::RawValue kTrueFourCC = SLANG_FOUR_CC('t', 'r', 'u', 'e'); +static const FourCC::RawValue kFalseFourCC = SLANG_FOUR_CC('f', 'a', 'l', 's'); + +// +// Strings are stored as a data chunk, with the payload of +// that chunk holding the bytes of the UTF-8 encoded string. +// The length of the string is stored as part of the chunk +// header. +// +// We also define a `FourCC` for raw data chunks, in anticipation +// of support for raw data being added to `ISerializerImpl` as +// an analogue of strings. +// + +static const FourCC::RawValue kStringFourCC = SLANG_FOUR_CC('s', 't', 'r', ' '); +static const FourCC::RawValue kDataFourCC = SLANG_FOUR_CC('d', 'a', 't', 'a'); + +// +// Containers (arrays, dictionaries, optionals, tuples, and structs) +// are stored as list chunks, with their elements as child chunks. +// + +static const FourCC::RawValue kArrayFourCC = SLANG_FOUR_CC('a', 'r', 'r', 'y'); +static const FourCC::RawValue kDictionaryFourCC = SLANG_FOUR_CC('d', 'i', 'c', 't'); +static const FourCC::RawValue kStructFourCC = SLANG_FOUR_CC('s', 't', 'r', 'c'); +static const FourCC::RawValue kTupleFourCC = SLANG_FOUR_CC('t', 'p', 'l', 'e'); +static const FourCC::RawValue kOptionalFourCC = SLANG_FOUR_CC('o', 'p', 't', '?'); + +// +// Null pointer values are simply stored as an empty data chunk with +// a distinct type. +// + +static const FourCC::RawValue kNullFourCC = SLANG_FOUR_CC('n', 'u', 'l', 'l'); + +// +// Non-null pointers are stored as a data chunk that references a +// serialized object by its `ObjectIndex`. +// + +using ObjectIndex = Int32; + +static const FourCC::RawValue kObjectReferenceFourCC = SLANG_FOUR_CC('o', 'b', 'j', 'r'); + +// +// All of the objects transitively referenced in the serialized object +// graph are stored in a list chunk of object definitions, with one +// chunk per object. The object definitions themselves are ordinary +// values using any of the cases above. +// + +static const FourCC::RawValue kObjectDefinitionListFourCC = SLANG_FOUR_CC('o', 'b', 'j', 's'); + +// +// The first child of the root chunk will be the object definition list +// chunk, and that will be followed by zero or more "root values" that +// have been serialized. +// + +} // namespace RIFFSerial + +/// Serializer implementation for writing to a tree of RIFF chunks. +struct RIFFSerialWriter : ISerializerImpl +{ +public: + /// Construct a writer to append to the given RIFF `chunk`. + /// + /// The object graph will be serialized into a child chunk + /// of `chunk`, as a list chunk with the given `type`. + /// + RIFFSerialWriter(RIFF::ChunkBuilder* chunk, FourCC type = RIFFSerial::kRootFourCC); + + + /// Construct a writer to write an entire RIFF file. + /// + /// The object graph will be serialized as the root chunk + /// of `riff`, with the given `type`. + /// + RIFFSerialWriter(RIFF::Builder& riff, FourCC type = RIFFSerial::kRootFourCC); + + /// Finalize writing. + /// + /// Any pending operations needed to write the entire object + /// graph will be flushed. + /// + ~RIFFSerialWriter(); + +private: + RIFFSerialWriter() = delete; + + /// Cursor for where in the RIFF hierarchy we are writing. + RIFF::BuildCursor _cursor; + + /// Representation of an index into the object list. + using ObjectIndex = RIFFSerial::ObjectIndex; + + /// Information about an object that should be + /// added to the object definition list. + struct ObjectInfo + { + /// Pointer to the in-memory C++ object. + void* ptr; + + /// Callback that can be invoked to serialize the object's data. + Callback callback; + + /// User-data pointer for `callback` + void* userData; + }; + + /// The chunk where object definitions are listed. + RIFF::ListChunkBuilder* _objectDefinitionListChunk = nullptr; + + /// Information on the objects that have been referenced, + /// and which need their definitions to be serialized into + /// the object definition list chunk. + /// + List<ObjectInfo> _objects; + Index _writtenObjectDefinitionCount = 0; + + /// Maps the address of an in-memory C++ object to the + /// corresponding entry in `_objects`, if any. + Dictionary<void*, ObjectIndex> _mapPtrToObjectIndex; + + + void _initialize(FourCC type); + void _flush(); + + void _writeInt(Int64 value); + void _writeUInt(UInt64 value); + void _writeFloat(double value); + + void _writeObjectReference(ObjectIndex index); + +private: + // + // The following declarations are the requirements + // of the `ISerializerImpl` interface: + // + + virtual SerializationMode getMode() override; + + virtual void handleBool(bool& value) override; + + virtual void handleInt8(int8_t& value) override; + virtual void handleInt16(int16_t& value) override; + virtual void handleInt32(Int32& value) override; + virtual void handleInt64(Int64& value) override; + + virtual void handleUInt8(uint8_t& value) override; + virtual void handleUInt16(uint16_t& value) override; + virtual void handleUInt32(UInt32& value) override; + virtual void handleUInt64(UInt64& value) override; + + virtual void handleFloat32(float& value) override; + virtual void handleFloat64(double& value) override; + + virtual void handleString(String& value) override; + + virtual void beginArray() override; + virtual void endArray() override; + + virtual void beginDictionary() override; + virtual void endDictionary() override; + + virtual bool hasElements() override; + + virtual void beginStruct() override; + virtual void handleFieldKey(char const* name, Int index) override; + virtual void endStruct() override; + + virtual void beginTuple() override; + virtual void endTuple() override; + + virtual void beginOptional() override; + virtual void endOptional() override; + + virtual void handleSharedPtr(void*& value, Callback callback, void* userData) override; + virtual void handleUniquePtr(void*& value, Callback callback, void* userData) override; + + virtual void handleDeferredObjectContents(void* valuePtr, Callback callback, void* userData) + override; +}; + +/// Serializer implementation for reading from a tree of RIFF chunks. +struct RIFFSerialReader : ISerializerImpl +{ +public: + /// Construct a reader to read data from the given `chunk`. + /// + /// Will validate that the given `chunk` is a list chunk + /// matching the expected `type`. + /// + RIFFSerialReader(RIFF::Chunk const* chunk, FourCC type = RIFFSerial::kRootFourCC); + + /// Finalize the reader. + /// + /// This will flush any outstanding operations that + /// might be pending. + /// + ~RIFFSerialReader(); + +private: + /// Representation of a read cursor in the serialized RIFF data. + using Cursor = RIFF::BoundsCheckedChunkPtr; + + /// Current cursor in the serialized RIFF data. + Cursor _cursor; + + void _advanceCursor(); + + /// A stack of saved cursors, reflecting the nesting + /// hierarchy of container chunks being read from. + /// + List<Cursor> _stack; + + void _pushCursor(); + void _popCursor(); + + /// Representation of an index into the object list. + using ObjectIndex = RIFFSerial::ObjectIndex; + + /// State of a serialized object that may or may not have been read already. + enum class ObjectState + { + Unread, + ReadingInProgress, + ReadingComplete, + }; + + /// Information about a serialized object in the object definition list. + struct ObjectInfo + { + /// State of the object. + /// + ObjectState state = ObjectState::Unread; + + /// Pointer to an in-memory C++ object representing the serialized object. + /// + /// Should only be accessed with consideration of what `state` is: + /// + /// * If `state` is `Unread`, then this should always be null. + /// + /// * If `state` is `ReadingComplete`, then this should be be + /// a valid pointer to the in-memory representation of the serialized object + /// (or null, if client code chose to deserialize it into a null pointer + /// for some reason). + /// + /// * If `state` is `ReadingInProgress`, then this might be a null pointer, + /// indicating that the logic to deserialize the object is currently + /// running (but has not yet allocated a representation and set it), or + /// it might be non-null indicating that the in-memory representation + /// has been allocated. + /// + /// Even if `ptr` is non-null, it may not be safe to access the + /// contents of the pointed-to object, because there may be deferred + /// operations pending to read some or all of its members. + /// + void* ptr = nullptr; + + /// The chunk that holds the definition of this object. + RIFF::Chunk const* definitionChunk = nullptr; + }; + + /// All of the objects from the object definition list chunk. + List<ObjectInfo> _objects; + + /// A serialization action that has been deferred. + /// + /// Deferred actions are typically used to put off recursively + /// reading all of the members of an object, thus avoiding + /// the potential for unbounded or even infinite recursion. + /// + struct DeferredAction + { + /// The in-memory object that the action should apply to. + void* valuePtr; + + /// The value of `_cursor` at the time this action was deferred. + Cursor savedCursor; + + /// The callback to apply to read data into the `valuePtr` + Callback callback; + + /// The user-data pointer for the `callback`. + void* userData; + }; + + /// Deferred actions that are still pending. + /// + /// As long as this array is non-empty, the contents of + /// in-memory objects read from the serialized data should + /// not be inspected/used. + /// + List<DeferredAction> _deferredActions; + + void _initialize(FourCC type); + void _flush(); + + FourCC _peekChunkType(); + + Int64 _readInt(); + UInt64 _readUInt(); + double _readFloat(); + + ObjectIndex _readObjectReference(); + + void _readDataChunk(void* outData, size_t dataSize); + + template<typename T> + T _readDataChunk() + { + T value; + _readDataChunk(&value, sizeof(value)); + return value; + } + + void _beginListChunk(FourCC type); + void _endListChunk(); + +private: + // + // The following declarations are the requirements + // of the `ISerializerImpl` interface: + // + + virtual SerializationMode getMode() override; + + virtual void handleBool(bool& value) override; + + virtual void handleInt8(int8_t& value) override; + virtual void handleInt16(int16_t& value) override; + virtual void handleInt32(Int32& value) override; + virtual void handleInt64(Int64& value) override; + + virtual void handleUInt8(uint8_t& value) override; + virtual void handleUInt16(uint16_t& value) override; + virtual void handleUInt32(UInt32& value) override; + virtual void handleUInt64(UInt64& value) override; + + virtual void handleFloat32(float& value) override; + virtual void handleFloat64(double& value) override; + + virtual void handleString(String& value) override; + + virtual void beginArray() override; + virtual void endArray() override; + + virtual void beginDictionary() override; + virtual void endDictionary() override; + + virtual bool hasElements() override; + + virtual void beginStruct() override; + virtual void handleFieldKey(char const* name, Int index) override; + virtual void endStruct() override; + + virtual void beginTuple() override; + virtual void endTuple() override; + + virtual void beginOptional() override; + virtual void endOptional() override; + + virtual void handleSharedPtr(void*& value, Callback callback, void* userData) override; + virtual void handleUniquePtr(void*& value, Callback callback, void* userData) override; + + virtual void handleDeferredObjectContents(void* valuePtr, Callback callback, void* userData) + override; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index 4b864fc94..d76ab8338 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -2,570 +2,978 @@ #ifndef SLANG_SERIALIZE_H #define SLANG_SERIALIZE_H -// #include <type_traits> - -#include "../compiler-core/slang-name.h" -#include "../core/slang-byte-encode-util.h" -#include "../core/slang-riff.h" -#include "../core/slang-stream.h" -#include "slang-serialize-types.h" +// This file defines an API for serialization. +// +// The API is intended to support multiple serialization formats, +// and to work with complicated object graphs that may include +// shared pointers, circular references, and so on. +// +// For anybody who don't want to dig into the details, the short +// version is that if you have a user-defined type like: +// +// // my-thing.h +// ... +// +// struct MyThing +// { +// float a; +// List<OtherThing> otherThings; +// SomeObject* object; +// }; +// +// then you can declare serialization support for your type +// with something like: +// +// // my-thing.h +// ... +// #include "slang-serialize.h" +// ... +// +// struct MyThing { ... } +// +// void serialize(Serializer const& serializer, MyThing& value); +// +// and then implement that support with something like: +// +// // my-thing.cpp +// #include "my-thing.h" +// +// ... +// +// void serialize(Serializer const& serializer, MyThing& value) +// { +// SLANG_SCOPED_SERIALIZER_STRUCT(serializer); +// serialize(serializer, value.a); +// serialize(serializer, value.otherThings); +// serialize(serializer, value.object); +// } +// +// That's it. So long as the `OtherThing` and `SomeObject` types used +// in the declaration of `MyType` already implemented serialization +// support, your new type should be fully serializable. +// + +#include "../core/slang-basic.h" + +#include <optional> namespace Slang { -class Linkage; - -/* -A discussion of the serialization system design can be found in - -docs/design/serialization.md -*/ +// +// A central design choice of this serialization system is that +// both reading and writing of serialized data for a type are +// implemented using a single function. This choice makes it +// easier for a developer to be certain that the reading and +// writing code for a type are consistent with one another. +// +// In some cases, however, a serialization function may need +// to know whether it is reading or writing serialized data. +// For that reason, we define a simple `enum` to represent +// the different modes of operation. +// + +/// Whether serialized data is being read or written. +enum class SerializationMode +{ + Read, + Write, +}; -// Predeclare -typedef uint32_t SerialSourceLoc; -class NodeBase; -class Val; -struct ValNodeDesc; +// +// In order to support different serialized formats, and to +// abstract over the difference between reading and writing, +// we define a base interface for serialization. This interface +// is somewhat user-unfriendly, and is *not* intended for +// ordinary code to interface with directly. +// + +/// Base interface for serialization. +/// +/// Can be used for both reading and writing of serialized data. +/// +struct ISerializerImpl +{ + /// Get the mode that this serializer is operating in (reading or writing). + virtual SerializationMode getMode() = 0; + + /// Handle a boolean value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleBool(bool& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleInt8(int8_t& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleInt16(int16_t& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleInt32(Int32& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleInt64(Int64& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleUInt8(uint8_t& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleUInt16(uint16_t& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleUInt32(UInt32& value) = 0; + + /// Handle an integer value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleUInt64(UInt64& value) = 0; + + /// Handle a floating-point value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleFloat32(float& value) = 0; + + /// Handle a floating-point value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleFloat64(double& value) = 0; + + /// Handle a string value. + /// + /// If the serializer is writing, then `value` will be + /// written to the serialized format. + /// + /// If the serializer is reading, then `value` will be + /// set to the value read from the serialized format. + /// + virtual void handleString(String& value) = 0; + + /// Begin serializing an array value. + /// + /// An array should be used to serialize an + /// unkeyed homogeneous collection of a varying + /// number of elements. + /// + /// This operation must be properly paired with a + /// call to `endArray()`. + /// + /// When writing, the values serialized between `beginArray()` + /// and `endArray()` will be written as the elements of a + /// serialized array. + /// + /// When reading, the user should call `hasElements()` to + /// test whether there are elements remaining to be read, + /// and serialize values in a loop until `hasElements()` + /// returns `false`. + /// + virtual void beginArray() = 0; + + /// End serializing an array value. + virtual void endArray() = 0; + + /// Begin serializing an optional value. + /// + /// An optional should be used to serialize a + /// collection that logically has either zero + /// or one element. + /// + /// This operation must be properly paired with a + /// call to `endOptional()`. + /// + /// When writing, a value serialized between `beginOptional()` + /// and `endOptional()` will be written as the value of + /// the serialized optional. If no value is serialized, + /// then the optional will be empty. + /// + /// When reading, the user should call `hasElements()` to + /// test whether the serialized optional has a value and, + /// if it does, read the value before calling `endOptional()`. + /// + virtual void beginOptional() = 0; + + /// End serializing an optional value. + virtual void endOptional() = 0; + + /// Begin serializing a dictionary value. + /// + /// A dictionary should be used to serialize a + /// keyed homogeneous collection of a varying + /// number of elements. The elements of a dictioanry + /// are key-value pairs (that is, two-element tuples). + /// + /// Formats are required to support dictionaries with + /// any serializable type as the key, not just strings. + /// + /// This operation must be properly paired with a + /// call to `endDictionary()`. + /// + /// When writing, the values serialized between `beginDictionary()` + /// and `endDictionary()` will be written as the elements of a + /// serialized dictionary. + /// + /// When reading, the user should call `hasElements()` to + /// test whether there are elements remaining to be read, + /// and serialize values in a loop until `hasElements()` + /// returns `false`. + /// + virtual void beginDictionary() = 0; + + /// End serializing a dictionary value. + virtual void endDictionary() = 0; + + /// Check whether there are elements remaining to be read + /// from a serialized container. + /// + /// It is invalid to call this function except between paired + /// `beginArray()`/`endArray()`, beginDictionary()`/`endDictionary()`, + /// or `beginOptional()`/`endOptional()` calls. + /// + virtual bool hasElements() = 0; + + /// Begin serializing a tuple value. + /// + /// A tuple should be used to serialize an + /// unkeyed heterogeneous collection of a fixed + /// number of elements. + /// + /// It is up to the concrete implementation whether calls + /// to `hasElements()` are allowed between `beginTuple()` + /// and `endTuple()`. + /// + virtual void beginTuple() = 0; + + /// End serializing a tuple value. + virtual void endTuple() = 0; + + + /// Begin serializing a struct value. + /// + /// A struct should be used to serialize an + /// keyed heterogeneous collection of a fixed + /// number of elements. + /// + /// The value of each struct field should be + /// preceded by a call to `handleFieldKey()`, + /// which specifies the field being serialized. + /// + /// It is up to the concrete implementation whether + /// a fields can be read in a different order than + /// they were written, and how to handle attempts + /// to read a field that was not written. + /// + virtual void beginStruct() = 0; + + /// End serializing a struct value. + virtual void endStruct() = 0; + + /// Set the key for the next struct field to be serialized. + /// + /// If no name is available for the field, `name` may be `nullptr`. + /// + /// If no index is available for the field, `index` may be `-1`. + /// + /// A user must pass either a valid `name` or `index. + /// + virtual void handleFieldKey(char const* name, Int index) = 0; + + /// A callback function used to handle serialization of pointers. + typedef void (*Callback)(void* valuePtr, void* userData); + + /// Handle a pointer value that is expected to be unique. + /// + /// A unique pointer is logically similar to an optional value. + /// + /// If the pointer value being read/written is null, then + /// the function returns without invoking `callback`. + /// + /// When reading, if the serialized value is non-null, + /// then the callback will be invoked as `callback(&value, userData)`. + /// The callback is expected to read the members of the pointed-to + /// type and set `value` to some object (whether newly constructed + /// or looked up). + /// + /// When writing, if the `value` is non-null, then the callback + /// will be invoked, either immediately or at some later point, + /// as `callback(&ptr, userData)` where `ptr` is a variable + /// holding a copy of the `value` that was passed in. The callback + /// is expected to write the members of the pointed-to type. + /// + /// If the `callback` is invoked at some later point, rather than + /// immediately, the concrete serializer implementation is responsible + /// for ensuring that its internal state has been restored to + /// be compatible with what it was when `handleUniquePtr` was called. + /// + virtual void handleUniquePtr(void*& value, Callback callback, void* userData) = 0; + + /// Handle a pointer value that may have multiple references. + /// + /// This operation is similar to `handleUniquePtr` with the following + /// differences: + /// + /// * When writing, if the same pointer value has been seen before, + /// the `callback` will not be invoked, and instead an additional + /// reference to the previously-serialized value will be written. + /// + /// * When reading, if the serialized value has been read before, + /// the `callback` will not be invoked, and instead `value` will + /// be set to the pointer that was previously read. + /// + virtual void handleSharedPtr(void*& value, Callback callback, void* userData) = 0; + + /// Defer serialization of the contents of an object. + /// + /// Used to delay serialization of members of an object that + /// could cause infinite recursion if serialized eagerly. + /// + /// This operation should only be used in the body of a callback + /// passed to `handleUniquePtr()` or `handleSharedPtr()`. + /// + /// This operation schedules the given `callback` to be called + /// at some later point a `callback(value, userData)`, with + /// the state of the serializer implementation restored to what + /// it was when `handleDeferredObjectContents()` was called. + /// + /// Some concrete serializer implementations might implement + /// this operation by invoking `callback` immediately. + /// + virtual void handleDeferredObjectContents(void* value, Callback callback, void* userData) = 0; +}; -struct Encoder +// +// Rather than interact with instances of `ISerializerImpl` directly, +// most client code will use a wrapper type that amounts to a kind +// of smart pointer. +// +// While the `ISerializerImpl` interface can cover a wide range of +// types that need to be serialized, it is common for types to require +// more specific context to be available in order to perform serialization. +// For example, code might need access to a factory object in order +// to construct objects of a type being read. +// +// To support more specialized serializer implementations, we allow +// the smart pointer used for a serializer to depend on the type +// of the underlying implementation object. +// + +/// Base type for serialization contexts. +/// +/// The type parameter `T` should be a type of object that +/// holds the context information needed. +/// +template<typename T> +struct SerializerBase { public: - Encoder() {} - - Encoder(RIFF::Builder& riff) - : _cursor(riff) + SerializerBase() = default; + SerializerBase(T* ptr) + : _ptr(ptr) { } - Encoder(RIFF::ListChunkBuilder* chunk) - : _cursor(chunk) - { - } + T* get() const { return _ptr; } + T* operator->() const { return get(); } - void beginArray(FourCC typeCode) { _cursor.beginListChunk(typeCode); } +private: + T* _ptr = nullptr; +}; - void beginArray() { beginArray(SerialBinary::kArrayFourCC); } +/// A serialization context. +/// +/// The type parameter `T` should be a type of object that +/// holds the context information needed. +/// +template<typename T> +struct Serializer_ : SerializerBase<T> +{ + using SerializerBase<T>::SerializerBase; +}; - void endArray() { _cursor.endChunk(); } +/// Default serialization context. +using Serializer = Serializer_<ISerializerImpl>; - void beginObject(FourCC typeCode) { _cursor.beginListChunk(typeCode); } +// +// We define namespace-scope functions that mirror some +// of the operations of `ISerializerImpl`, so that they +// can be invoked on any type that is contextually +// convertible to a `Serializer`. This allows users +// to define their own serialization context types while +// still being able to take advantage of the utility +// operations in this file for serializing basic types, +// arrays, dictionaries, etc. +// - void beginObject() { beginObject(SerialBinary::kObjectFourCC); } - void endObject() { _cursor.endChunk(); } +/// Get the mode of `serializer`. +inline SerializationMode getMode(Serializer const& serializer) +{ + return serializer->getMode(); +} - void beginKeyValuePair(FourCC keyCode) { _cursor.beginListChunk(keyCode); } +/// Check if `serializer` is reading serialized data. +inline bool isReading(Serializer const& serializer) +{ + return getMode(serializer) == SerializationMode::Read; +} - void beginKeyValuePair() { beginKeyValuePair(SerialBinary::kPairFourCC); } +/// Check if `serializer` is writing serialized data. +inline bool isWriting(Serializer const& serializer) +{ + return getMode(serializer) == SerializationMode::Write; +} - void endKeyValuePair() { _cursor.endChunk(); } +/// Check if `serializer` has more container elements. +inline bool hasElements(Serializer const& serializer) +{ + return serializer->hasElements(); +} - void encodeData(FourCC typeCode, void const* data, size_t size) - { - _cursor.addDataChunk(typeCode, data, size); - } +inline void serialize(Serializer const& serializer, bool& value) +{ + serializer->handleBool(value); +} - void encodeData(void const* data, size_t size) - { - encodeData(SerialBinary::kDataFourCC, data, size); - } +inline void serialize(Serializer const& serializer, int8_t& value) +{ + serializer->handleInt8(value); +} - void encode(std::nullptr_t) { encodeData(SerialBinary::kNullFourCC, nullptr, 0); } +inline void serialize(Serializer const& serializer, int16_t& value) +{ + serializer->handleInt16(value); +} - void encodeBool(bool value) - { - encodeData(value ? SerialBinary::kTrueFourCC : SerialBinary::kFalseFourCC, nullptr, 0); - } +inline void serialize(Serializer const& serializer, Int32& value) +{ + serializer->handleInt32(value); +} - void encodeInt(Int64 value) - { - if (Int32(value) == value) - { - auto v = Int32(value); - encodeData(SerialBinary::kInt32FourCC, &v, sizeof(v)); - } - else - { - encodeData(SerialBinary::kInt64FourCC, &value, sizeof(value)); - } - } +inline void serialize(Serializer const& serializer, Int64& value) +{ + serializer->handleInt64(value); +} +inline void serialize(Serializer const& serializer, uint8_t& value) +{ + serializer->handleUInt8(value); +} - void encodeUInt(UInt64 value) - { - if (UInt32(value) == value) - { - auto v = UInt32(value); - encodeData(SerialBinary::kUInt32FourCC, &v, sizeof(v)); - } - else - { - encodeData(SerialBinary::kUInt64FourCC, &value, sizeof(value)); - } - } +inline void serialize(Serializer const& serializer, uint16_t& value) +{ + serializer->handleUInt16(value); +} - void encode(Int32 value) { encodeInt(value); } - void encode(Int64 value) { encodeInt(value); } +inline void serialize(Serializer const& serializer, UInt32& value) +{ + serializer->handleUInt32(value); +} - void encode(UInt32 value) { encodeUInt(value); } - void encode(UInt64 value) { encodeUInt(value); } +inline void serialize(Serializer const& serializer, UInt64& value) +{ + serializer->handleUInt64(value); +} - void encode(float value) { encodeData(SerialBinary::kFloat32FourCC, &value, sizeof(value)); } +inline void serialize(Serializer const& serializer, float& value) +{ + serializer->handleFloat32(value); +} - void encode(double value) { encodeData(SerialBinary::kFloat64FourCC, &value, sizeof(value)); } +inline void serialize(Serializer const& serializer, double& value) +{ + serializer->handleFloat64(value); +} - void encodeString(String const& value) +inline void serialize(Serializer const& serializer, String& value) +{ + serializer->handleString(value); +} + +/// Serialize an `enum` value via an intermediate integer type. +/// +/// This function serializes a value of `EnumType`, by +/// converting it to/from the given `RawType` for storage +/// in the serialized format. +/// +template<typename RawType = Int32, typename EnumType> +void serializeEnum(Serializer const& serializer, EnumType& value) +{ + auto raw = RawType(value); + serialize(serializer, raw); + value = EnumType(raw); +} + +// +// We define a suite of simple RAII types to help users +// maintain the proper pairing of begin/end operations +// when interacting with an `ISerializerImpl`, and for +// each of those types we define a macro to simplify +// introducing a coresponding scope. +// + +struct ScopedSerializerArray +{ +public: + ScopedSerializerArray(Serializer const& serializer) + : _serializer(serializer) { - Int size = value.getLength(); - encodeData(SerialBinary::kStringFourCC, value.getBuffer(), size); + serializer->beginArray(); } + ~ScopedSerializerArray() { _serializer->endArray(); } - void encode(String const& value) { encodeString(value); } - - struct WithArray - { - public: - WithArray(Encoder* encoder) - : _encoder(encoder) - { - encoder->beginArray(); - } - - WithArray(Encoder* encoder, FourCC typeCode) - : _encoder(encoder) - { - encoder->beginArray(typeCode); - } - - ~WithArray() { _encoder->endArray(); } - - private: - Encoder* _encoder; - }; +private: + Serializer _serializer; +}; - struct WithObject +struct ScopedSerializerDictionary +{ +public: + ScopedSerializerDictionary(Serializer const& serializer) + : _serializer(serializer) { - public: - WithObject(Encoder* encoder) - : _encoder(encoder) - { - encoder->beginObject(); - } - - WithObject(Encoder* encoder, FourCC typeCode) - : _encoder(encoder) - { - encoder->beginObject(typeCode); - } + serializer->beginDictionary(); + } - ~WithObject() { _encoder->endObject(); } + ~ScopedSerializerDictionary() { _serializer->endDictionary(); } - private: - Encoder* _encoder; - }; +private: + Serializer _serializer; +}; - struct WithKeyValuePair +struct ScopedSerializerStruct +{ +public: + ScopedSerializerStruct(Serializer const& serializer) + : _serializer(serializer) { - public: - WithKeyValuePair(Encoder* encoder) - : _encoder(encoder) - { - encoder->beginKeyValuePair(); - } - - WithKeyValuePair(Encoder* encoder, FourCC typeCode) - : _encoder(encoder) - { - encoder->beginKeyValuePair(typeCode); - } - - ~WithKeyValuePair() { _encoder->endKeyValuePair(); } + serializer->beginStruct(); + } - private: - Encoder* _encoder; - }; + ~ScopedSerializerStruct() { _serializer->endStruct(); } private: - RIFF::BuildCursor _cursor; + Serializer _serializer; +}; +struct ScopedSerializerTuple +{ public: - operator RIFF::BuildCursor&() { return _cursor; } + ScopedSerializerTuple(Serializer const& serializer) + : _serializer(serializer) + { + serializer->beginTuple(); + } - RIFF::ChunkBuilder* getRIFFChunk() { return _cursor.getCurrentChunk(); } + ~ScopedSerializerTuple() { _serializer->endTuple(); } - void setRIFFChunk(RIFF::ChunkBuilder* chunk) { _cursor.setCurrentChunk(chunk); } +private: + Serializer _serializer; }; -struct Decoder +struct ScopedSerializerOptional { public: - Decoder(RIFF::Chunk const* chunk) - : _cursor(chunk) + ScopedSerializerOptional(Serializer const& serializer) + : _serializer(serializer) { + serializer->beginOptional(); } - bool decodeBool() - { - switch (getTag()) - { - case SerialBinary::kTrueFourCC: - _advanceCursor(); - return true; - case SerialBinary::kFalseFourCC: - _advanceCursor(); - return false; - - default: - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(false); - } - } - - String decodeString() - { - if (getTag() != SerialBinary::kStringFourCC) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(""); - } - - auto dataChunk = as<RIFF::DataChunk>(_cursor); - if (!dataChunk) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(""); - } + ~ScopedSerializerOptional() { _serializer->endOptional(); } - auto size = dataChunk->getPayloadSize(); +private: + Serializer _serializer; +}; - String value; - value.appendRepeatedChar(' ', size); - dataChunk->writePayloadInto((char*)value.getBuffer(), size); - _advanceCursor(); - return value; +#define SLANG_SCOPED_SERIALIZER_ARRAY(SERIALIZER) \ + ::Slang::ScopedSerializerArray SLANG_CONCAT(_scopedSerializerArray, __LINE__)(SERIALIZER) + +#define SLANG_SCOPED_SERIALIZER_DICTIONARY(SERIALIZER) \ + ::Slang::ScopedSerializerDictionary SLANG_CONCAT(_scopedSerializerDictionary, __LINE__)( \ + SERIALIZER) + +#define SLANG_SCOPED_SERIALIZER_OPTIONAL(SERIALIZER) \ + ::Slang::ScopedSerializerOptional SLANG_CONCAT(_scopedSerializerOptional, __LINE__)(SERIALIZER) + +#define SLANG_SCOPED_SERIALIZER_STRUCT(SERIALIZER) \ + ::Slang::ScopedSerializerStruct SLANG_CONCAT(_scopedSerializerStruct, __LINE__)(SERIALIZER) + +#define SLANG_SCOPED_SERIALIZER_TAGGED_UNION(SERIALIZER) \ + ::Slang::ScopedSerializerStruct SLANG_CONCAT(_scopedSerializerStruct, __LINE__)(SERIALIZER) + +#define SLANG_SCOPED_SERIALIZER_TUPLE(SERIALIZER) \ + ::Slang::ScopedSerializerTuple SLANG_CONCAT(_scopedSerializerTuple, __LINE__)(SERIALIZER) + +// +// Containers like arrays and dictionaries are more +// difficult to serialize than typical user-defined +// types for a few reasons: +// +// * They typically need to have distinct code paths +// for reading and writing, so they don't benefit +// much from having a unified read/write abstraction. +// +// * They need to be written as templates, to abstract +// over the element type, and thus need to be +// defined in headers. +// +// * Because the element type might require a more +// specialized type of serialization context, they +// also need to be templated on the type of the +// serializer itself. +// +// With all that said, the definitions themselves +// are fairly straightforward. All we have to do is +// branch on whether we are reading or writing and +// either iterate over the serialized data to fill +// the collection (when reading), or iterate over +// the collection to serialize its elements (when +// writing). +// + +template<typename S, typename T> +void serialize(S const& serializer, List<T>& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) + { + for (auto element : value) + serialize(serializer, element); } - - void decodeData(FourCC typeTag, void* outData, size_t dataSize) + else { - if (getTag() == typeTag) + value.clear(); + while (hasElements(serializer)) { - auto dataChunk = as<RIFF::DataChunk>(_cursor); - if (dataChunk) - { - auto payloadSize = dataChunk->getPayloadSize(); - if (payloadSize >= dataSize) - { - dataChunk->writePayloadInto(outData, dataSize); - _advanceCursor(); - return; - } - } + T element; + serialize(serializer, element); + value.add(element); } - - SLANG_UNEXPECTED("invalid format in RIFF"); } +} - template<typename T> - T _decodeSimpleValue(FourCC typeTag) +template<typename S, typename T, size_t N> +void serialize(S const& serializer, T (&value)[N]) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - T value; - decodeData(typeTag, &value, sizeof(value)); - return value; + for (auto element : value) + serialize(serializer, element); } - - Int64 decodeInt() + else { - switch (getTag()) + size_t index = 0; + while (hasElements(serializer)) { - case SerialBinary::kInt64FourCC: - return _decodeSimpleValue<Int64>(getTag()); - case SerialBinary::kInt32FourCC: - return _decodeSimpleValue<Int32>(getTag()); - - case SerialBinary::kUInt32FourCC: - return _decodeSimpleValue<UInt32>(getTag()); + T element; + serialize(serializer, element); - case SerialBinary::kUInt64FourCC: + if (index >= N) { - auto uintValue = _decodeSimpleValue<UInt64>(getTag()); - if (Int64(uintValue) < 0) - { - SLANG_UNEXPECTED("signed/unsigned mismatch in RIFF"); - } - return Int64(uintValue); + SLANG_UNEXPECTED("serialized array too large"); } - - default: - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(0); + value[index++] = element; } } +} - UInt64 decodeUInt() +template<typename S, typename T, int N> +void serialize(S const& serializer, ShortList<T, N>& value) +{ + SLANG_SCOPED_SERIALIZER_ARRAY(serializer); + if (isWriting(serializer)) { - switch (getTag()) - { - case SerialBinary::kUInt64FourCC: - return _decodeSimpleValue<UInt64>(getTag()); - case SerialBinary::kUInt32FourCC: - return _decodeSimpleValue<UInt32>(getTag()); - - case SerialBinary::kInt32FourCC: - case SerialBinary::kInt64FourCC: - { - auto intValue = decodeInt(); - if (intValue < 0) - { - SLANG_UNEXPECTED("signed/unsigned mismatch in RIFF"); - } - return UInt64(intValue); - } - - default: - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(0); - } + for (auto element : value) + serialize(serializer, element); } - - double decodeFloat() + else { - switch (getTag()) + value.clear(); + while (hasElements(serializer)) { - case SerialBinary::kFloat32FourCC: - return _decodeSimpleValue<float>(getTag()); - case SerialBinary::kFloat64FourCC: - return _decodeSimpleValue<double>(getTag()); - - default: - SLANG_UNEXPECTED("invalid format in RIFF"); - UNREACHABLE_RETURN(0); + T element; + serialize(serializer, element); + value.add(element); } } +} - Int32 decodeInt32() { return Int32(decodeInt()); } - Int64 decodeInt64() { return decodeInt(); } - - UInt32 decodeUInt32() { return UInt32(decodeUInt()); } - UInt64 decodeUInt64() { return decodeUInt(); } - - float decodeFloat32() { return float(decodeFloat()); } - double decodeFloat64() { return decodeFloat(); } - - FourCC getTag() { return _cursor ? _cursor->getType() : FourCC(0); } - - Int32 _decodeImpl(Int32*) { return decodeInt32(); } - UInt32 _decodeImpl(UInt32*) { return decodeUInt32(); } - - Int64 _decodeImpl(Int64*) { return decodeInt64(); } - UInt64 _decodeImpl(UInt64*) { return decodeUInt64(); } - - float _decodeImpl(float*) { return decodeFloat32(); } - double _decodeImpl(double*) { return decodeFloat64(); } - - template<typename T> - T decode() - { - return _decodeImpl((T*)nullptr); - } - - template<typename T> - void decode(T& outValue) - { - outValue = _decodeImpl((T*)nullptr); - } - - void beginArray(FourCC typeCode = SerialBinary::kArrayFourCC) +template<typename S, typename T> +void serialize(S const& serializer, std::optional<T>& value) +{ + SLANG_SCOPED_SERIALIZER_OPTIONAL(serializer); + if (isWriting(serializer)) { - auto listChunk = as<RIFF::ListChunk>(_cursor); - if (!listChunk) + if (value.has_value()) { - SLANG_UNEXPECTED("invalid format in RIFF"); + serialize(serializer, *value); } - - if (listChunk->getType() != typeCode) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - } - - _cursor = listChunk->getFirstChild(); } - - void beginObject(FourCC typeCode = SerialBinary::kObjectFourCC) + else { - auto listChunk = as<RIFF::ListChunk>(_cursor); - if (!listChunk) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - } - - if (listChunk->getType() != typeCode) + value.reset(); + if (hasElements(serializer)) { - SLANG_UNEXPECTED("invalid format in RIFF"); + value.emplace(); + serialize(serializer, *value); } - - _cursor = listChunk->getFirstChild(); } +} - void beginKeyValuePair(FourCC typeCode = SerialBinary::kPairFourCC) - { - auto listChunk = as<RIFF::ListChunk>(_cursor); - if (!listChunk) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - } +template<typename S, typename K, typename V> +void serialize(S const& serializer, KeyValuePair<K, V>& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.key); + serialize(serializer, value.value); +} - if (listChunk->getType() != typeCode) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - } +template<typename S, typename K, typename V> +void serialize(S const& serializer, std::pair<K, V>& value) +{ + SLANG_SCOPED_SERIALIZER_TUPLE(serializer); + serialize(serializer, value.first); + serialize(serializer, value.second); +} - _cursor = listChunk->getFirstChild(); +template<typename S, typename K, typename V> +void serialize(S const& serializer, Dictionary<K, V>& value) +{ + SLANG_SCOPED_SERIALIZER_DICTIONARY(serializer); + if (isWriting(serializer)) + { + for (auto pair : value) + serialize(serializer, pair); } - - void beginProperty(FourCC propertyCode) + else { - auto listChunk = as<RIFF::ListChunk>(_cursor); - if (!listChunk) - { - SLANG_UNEXPECTED("invalid format in RIFF"); - } - - auto found = listChunk->findListChunk(propertyCode); - if (!found) + value.clear(); + while (hasElements(serializer)) { - SLANG_UNEXPECTED("invalid format in RIFF"); + KeyValuePair<K, V> pair; + serialize(serializer, pair); + value.add(pair.key, pair.value); } - - _cursor = found->getFirstChild(); } +} - bool hasElements() { return _cursor != nullptr; } - - bool isNull() +template<typename S, typename K, typename V> +void serialize(S const& serializer, OrderedDictionary<K, V>& value) +{ + SLANG_SCOPED_SERIALIZER_DICTIONARY(serializer); + if (isWriting(serializer)) { - if (_cursor == nullptr) - return true; - if (getTag() == SerialBinary::kNullFourCC) - return true; - return false; + for (auto pair : value) + serialize(serializer, pair); } - - bool decodeNull() + else { - if (!isNull()) - return false; - - if (_cursor != nullptr) + value.clear(); + while (hasElements(serializer)) { - _advanceCursor(); + KeyValuePair<K, V> pair; + serialize(serializer, pair); + value.add(pair.key, pair.value); } - return true; } +} + +// +// Serialization of pointers is the most complicated part of +// the whole system. Dealing with pointers means contending with: +// +// * Multiply-referenced objects, or even cycles in the object graph. +// +// * Polymoprhic types, where a `Derived*` might get serialized +// through a `Base*` pointer. +// +// * Types that require going through a factory function of +// some kind as part of their creation (perhaps to implement +// deduplication/caching). +// +// Our handling of pointers is thus broken down into several +// different steps/layers: +// +// * An ordinary overload of `serialize(s,v)` is used to intercept +// pointer types `T*` and dispatched out to `serializePtr(s,v,(T*)nullptr)`. +// Passing the additional `T*` argument allows different overloads +// of `serializePtr` to intercept entire type hierarchies, while +// still allowing for a fallback case. +// +// * Implementations of `serializePtr` are typically expected to +// invoke either `serializeUniquePtr` or `serializeSharedPtr`, which +// handle calling into the `ISerializerImpl` methods with appropriate +// callbacks. +// +// * The `handleUniquePtr()` or `handleSharedPtr()` operation on +// `ISerializerImpl` is expected to handle null pointers, or previously- +// encountered pointers in the shared case, and then invoke the +// callback to handle things when it can't early-out. +// +// * The callbacks will end up calling `serializeObject(s,v,(T*)nullptr)`, +// which is another customization point. The default implementation +// will call `new T()` when reading, so types that need more complicated +// creation logic should intercept this specialization point. +// +// * An implementation of `serializeObject()` should strive to serialize +// the bare minimum of members required to actually allocate the object +// (in the case where serialized data is being read), and then call +// `deferSerializeObjectContents()` to schedule the remainder of +// the data to be serialized. Maintaining that policy helps ensure +// that cycles in the object graph don't create problems. +// +// * `serializeObjectContents()` is the final customization point. By +// default it simply takes a `T* value` and does `serialize(..., *value)` +// to serialize the pointed-to `T` value. A custom implementation +// should serialize whatever members of the object weren't handled +// as part of the corresponding `serializeObject()` implementation. +// + +template<typename S, typename T> +void serializeObjectContents(S const& serializer, T* value, void*) +{ + serialize(serializer, *value); +} - using Cursor = RIFF::BoundsCheckedChunkPtr; - - struct WithArray - { - public: - WithArray(Decoder& decoder) - : _decoder(decoder) - { - _saved = decoder._cursor; - decoder.beginArray(); - } - - WithArray(Decoder& decoder, FourCC typeCode) - : _decoder(decoder) - { - _saved = decoder._cursor; - decoder.beginArray(typeCode); - } - - ~WithArray() { _decoder._cursor = _saved.getNextSibling(); } - - private: - Cursor _saved; - Decoder& _decoder; - }; - - struct WithObject - { - public: - WithObject(Decoder& decoder) - : _decoder(decoder) - { - _saved = decoder._cursor; - decoder.beginObject(); - } - - WithObject(Decoder& decoder, FourCC typeCode) - : _decoder(decoder) - { - _saved = decoder._cursor; - decoder.beginObject(typeCode); - } - - ~WithObject() { _decoder._cursor = _saved.getNextSibling(); } - - private: - Cursor _saved; - Decoder& _decoder; - }; - - struct WithKeyValuePair - { - public: - WithKeyValuePair(Decoder& decoder) - : _decoder(decoder) - { - _saved = decoder._cursor; - decoder.beginKeyValuePair(); - } - - WithKeyValuePair(Decoder& decoder, FourCC typeCode) - : _decoder(decoder) - { - _saved = decoder._cursor; - _decoder.beginKeyValuePair(typeCode); - } - - ~WithKeyValuePair() { _decoder._cursor = _saved.getNextSibling(); } - - private: - Cursor _saved; - Decoder& _decoder; - }; +template<typename S, typename T> +void _serializeObjectContentsCallback(void* valuePtr, void* userData) +{ + auto serializerImpl = (S*)userData; + auto value = (T*)valuePtr; + serializeObjectContents(Serializer_<S>(serializerImpl), value, (T*)nullptr); +} - struct WithProperty +template<typename S, typename T> +void deferSerializeObjectContents(Serializer_<S> const& serializer, T* value) +{ + ((Serializer)serializer) + ->handleDeferredObjectContents( + value, + _serializeObjectContentsCallback<S, T>, + serializer.get()); +} + +template<typename S, typename T> +void serializeObject(S const& serializer, T*& value, void*) +{ + if (isReading(serializer)) { - public: - WithProperty(Decoder& decoder, FourCC typeCode) - : _decoder(decoder) - { - _saved = decoder._cursor; - _decoder.beginProperty(typeCode); - } - - ~WithProperty() { _decoder._cursor = _saved.getNextSibling(); } + value = new T(); + } + deferSerializeObjectContents(serializer, value); +} - private: - Cursor _saved; - Decoder& _decoder; - }; +template<typename S, typename T> +void _serializeObjectCallback(void* valuePtr, void* userData) +{ + auto serializerImpl = (S*)userData; + auto& value = *(T**)valuePtr; + serializeObject(Serializer_<S>(serializerImpl), value, (T*)nullptr); +} - Cursor getCursor() const { return _cursor; } - void setCursor(Cursor const& cursor) { _cursor = cursor; } +template<typename S, typename T> +void serializeSharedPtr(Serializer_<S> const& serializer, T*& value) +{ + ((Serializer)serializer) + ->handleSharedPtr(*(void**)&value, _serializeObjectCallback<S, T>, serializer.get()); +} - RIFF::Chunk const* getCurrentChunk() const { return getCursor(); } +template<typename S, typename T> +void serializeUniquePtr(Serializer_<S> const& serializer, T*& value) +{ + ((Serializer)serializer) + ->handleUniquePtr(*(void**)&value, _serializeObjectCallback<S, T>, serializer.get()); +} -private: - void _advanceCursor() { _cursor = _cursor.getNextSibling(); } +template<typename S, typename T> +void serializePtr(S const& serializer, T*& value, void*) +{ + serializeSharedPtr(serializer, value); +} - Cursor _cursor; -}; +template<typename S, typename T> +void serialize(S const& serializer, T*& value) +{ + serializePtr(serializer, value, (T*)nullptr); +} +template<typename S, typename T> +void serialize(S const& serializer, RefPtr<T>& value) +{ + T* raw = value; + serialize(serializer, raw); + value = raw; +} } // namespace Slang |
