diff options
| author | Yong He <yonghe@outlook.com> | 2024-02-15 00:05:51 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-15 00:05:51 -0800 |
| commit | 5a623ec227726ad1d988a5d91f55f19b62a98e03 (patch) | |
| tree | 94a3fd2f00ce1a95035f39cd3571c9e97a70d24e /source/slang/slang-serialize.cpp | |
| parent | 2ced683f10fb82f63a2e2c3d7b5f099c53bb57b0 (diff) | |
Support loading serialized modules. (#3588)
* Support loading serialized modules.
* Fix.
* Fix vs solution files
* Fix glsl module loading.
* C++ fix.
* Fix.
* Try fix c++ error.
* Try fix.
* Fix.
* Fix.
Diffstat (limited to 'source/slang/slang-serialize.cpp')
| -rw-r--r-- | source/slang/slang-serialize.cpp | 167 |
1 files changed, 166 insertions, 1 deletions
diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp index 1f5b6942d..1c8abb8f9 100644 --- a/source/slang/slang-serialize.cpp +++ b/source/slang/slang-serialize.cpp @@ -3,6 +3,7 @@ #include "slang-ast-base.h" #include "slang-ast-builder.h" +#include "slang-check-impl.h" namespace Slang { @@ -222,6 +223,50 @@ SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags f m_ptrMap.add(nullptr, 0); } +struct SkipFunctionBodyRAII +{ + FunctionDeclBase* funcDecl = nullptr; + Stmt* oldBody = nullptr; + SkipFunctionBodyRAII(SerialWriter::Flags flags, const SerialClass* serialCls, const void* ptr) + { + if ((flags & SerialWriter::Flag::SkipFunctionBody) == 0) + return; + + if (serialCls->typeKind != SerialTypeKind::NodeBase) + return; + auto cls = serialCls; + while (cls) + { + auto astNodeType = (ASTNodeType)cls->subType; + if (astNodeType == ASTNodeType::FunctionDeclBase) + { + funcDecl = (FunctionDeclBase*)ptr; + break; + } + cls = cls->super; + } + if (funcDecl) + { + oldBody = funcDecl->body; + // We always need to include body of unsafeForceInlineEarly functions + // since they will need to be available at IR lowering time of the + // user module for pre-linking inling. + if (!isUnsafeForceInlineFunc(funcDecl)) + { + funcDecl->body = nullptr; + } + } + + } + ~SkipFunctionBodyRAII() + { + if (funcDecl) + { + funcDecl->body = oldBody; + } + } +}; + SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* ptr) { if (serialCls->flags & SerialClassFlag::DontSerialize) @@ -229,6 +274,16 @@ SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* return SerialIndex(0); } + if (serialCls->typeKind == SerialTypeKind::NodeBase && + ReflectClassInfo::isSubClassOf(serialCls->subType, Val::kReflectClassInfo)) + { + return writeValObject((Val*)ptr); + } + + // If we are skipping function bodies, set the body field to nullptr, and + // restore it after serialization. + SkipFunctionBodyRAII clearFunctionBodyRAII(m_flags, serialCls, ptr); + // This pointer cannot be in the map SLANG_ASSERT(m_ptrMap.tryGetValue(ptr) == nullptr); @@ -279,6 +334,62 @@ SerialIndex SerialWriter::writeObject(const NodeBase* node) return writeObject(serialClass, (const void*)node); } +SerialIndex SerialWriter::writeValObject(const Val* node) +{ + typedef SerialInfo::ValEntry ValEntry; + + size_t size = node->getOperandCount() * sizeof(SerialInfo::SerialValOperand); + ValEntry* nodeEntry = (ValEntry*)m_arena.allocateAligned(sizeof(ValEntry) + size, SerialInfo::MAX_ALIGNMENT); + + nodeEntry->typeKind = SerialTypeKind::NodeBase; + nodeEntry->subType = (SerialSubType)node->astNodeType; + nodeEntry->operandCount = (uint32_t)node->getOperandCount(); + nodeEntry->info = SerialInfo::makeEntryInfo(SerialInfo::MAX_ALIGNMENT); + + // We add before adding fields, so if the fields point to this, the entry will be set + auto index = _add(node, nodeEntry); + + ShortList<SerialIndex, 4> serializedOperands; + + for (Index i = 0; i < node->getOperandCount(); i++) + { + auto operand = node->m_operands[i]; + switch (operand.kind) + { + case ValNodeOperandKind::ConstantValue: + serializedOperands.add((SerialIndex)0); + break; + case ValNodeOperandKind::ValNode: + case ValNodeOperandKind::ASTNode: + serializedOperands.add(addPointer(operand.values.nodeOperand)); + break; + } + } + + SLANG_ASSERT(serializedOperands.getCount() == node->getOperandCount()); + + auto serialOperands = (SerialInfo::SerialValOperand*)(nodeEntry + 1); + for (Index i = 0; i < node->getOperandCount(); i++) + { + auto serialOperand = serialOperands + i; + auto operand = node->m_operands[i]; + serialOperand->type = (int)operand.kind; + switch (operand.kind) + { + case ValNodeOperandKind::ConstantValue: + serialOperand->payload = operand.values.intOperand; + break; + case ValNodeOperandKind::ValNode: + serialOperand->payload = (uint64_t)serializedOperands[i]; + break; + case ValNodeOperandKind::ASTNode: + serialOperand->payload = (uint64_t)serializedOperands[i]; + break; + } + } + return index; +} + SerialIndex SerialWriter::writeObject(const RefObject* obj) { const SerialRefObject* serialObj = as<const SerialRefObject>(obj); @@ -633,6 +744,9 @@ size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const auto serialClass = serialClasses->getSerialClass(typeKind, entry->subType); + if (ReflectClassInfo::isSubClassOf(entry->subType, Val::kReflectClassInfo)) + return sizeof(ValEntry) + static_cast<const ValEntry*>(this)->operandCount * sizeof(SerialValOperand); + // Align by the alignment of the entry size_t alignment = getAlignment(entry->info); size_t size = sizeof(ObjectEntry) + serialClass->size; @@ -722,6 +836,49 @@ SerialPointer SerialReader::getPointer(SerialIndex index) return ptr; } +SerialPointer SerialReader::getValPointer(SerialIndex index) +{ + if (index == SerialIndex(0)) + { + return SerialPointer(); + } + + SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount())); + + SerialPointer& ptr = m_objects[Index(index)]; + + if (ptr.m_ptr) + return ptr; + + const SerialInfo::ValEntry* entry = (SerialInfo::ValEntry*)m_entries[Index(index)]; + ValNodeDesc desc; + desc.type = (ASTNodeType)entry->subType; + auto readPtr = (SerialInfo::SerialValOperand*)(entry + 1); + for (uint32_t i = 0; i < entry->operandCount; i++) + { + auto serialOperand = readPtr[i]; + ValNodeOperand operand; + operand.kind = (ValNodeOperandKind)(serialOperand.type); + switch (operand.kind) + { + case ValNodeOperandKind::ConstantValue: + operand.values.intOperand = serialOperand.payload; + break; + case ValNodeOperandKind::ASTNode: + operand.values.nodeOperand = (NodeBase*)getPointer((SerialIndex)serialOperand.payload).m_ptr; + break; + case ValNodeOperandKind::ValNode: + operand.values.nodeOperand = (Val*)getValPointer((SerialIndex)serialOperand.payload).m_ptr; + break; + } + desc.operands.add(operand); + } + desc.init(); + ptr.m_kind = SerialTypeKind::NodeBase; + ptr.m_ptr = this->m_objectFactory->getOrCreateVal(_Move(desc)); + return ptr; +} + String SerialReader::getString(SerialIndex index) { if (index == SerialIndex(0)) @@ -902,6 +1059,12 @@ SlangResult SerialReader::constructObjects(NamePool* namePool) case SerialTypeKind::NodeBase: { auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry); + + // Don't create object for Vals. + if (objectEntry->typeKind == SerialTypeKind::NodeBase && + ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) + break; + void* obj = m_objectFactory->create(objectEntry->typeKind, objectEntry->subType); if (!obj) { @@ -912,7 +1075,7 @@ SlangResult SerialReader::constructObjects(NamePool* namePool) } case SerialTypeKind::Array: { - // Don't need to construct an object, as will be accessed an interpreted by the object that holds it + // Don't need to construct an object, as will be accessed and interpreted by the object that holds it break; } } @@ -944,6 +1107,8 @@ SlangResult SerialReader::deserializeObjects() { return SLANG_FAIL; } + if (ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo)) + continue; const uint8_t* src = (const uint8_t*)(objectEntry + 1); uint8_t* dst = (uint8_t*)dstPtr.m_ptr; |
