summaryrefslogtreecommitdiff
path: root/source/slang/slang-serialize.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-15 00:05:51 -0800
committerGitHub <noreply@github.com>2024-02-15 00:05:51 -0800
commit5a623ec227726ad1d988a5d91f55f19b62a98e03 (patch)
tree94a3fd2f00ce1a95035f39cd3571c9e97a70d24e /source/slang/slang-serialize.cpp
parent2ced683f10fb82f63a2e2c3d7b5f099c53bb57b0 (diff)
Support loading serialized modules. (#3588)
* Support loading serialized modules. * Fix. * Fix vs solution files * Fix glsl module loading. * C++ fix. * Fix. * Try fix c++ error. * Try fix. * Fix. * Fix.
Diffstat (limited to 'source/slang/slang-serialize.cpp')
-rw-r--r--source/slang/slang-serialize.cpp167
1 files changed, 166 insertions, 1 deletions
diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp
index 1f5b6942d..1c8abb8f9 100644
--- a/source/slang/slang-serialize.cpp
+++ b/source/slang/slang-serialize.cpp
@@ -3,6 +3,7 @@
#include "slang-ast-base.h"
#include "slang-ast-builder.h"
+#include "slang-check-impl.h"
namespace Slang {
@@ -222,6 +223,50 @@ SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags f
m_ptrMap.add(nullptr, 0);
}
+struct SkipFunctionBodyRAII
+{
+ FunctionDeclBase* funcDecl = nullptr;
+ Stmt* oldBody = nullptr;
+ SkipFunctionBodyRAII(SerialWriter::Flags flags, const SerialClass* serialCls, const void* ptr)
+ {
+ if ((flags & SerialWriter::Flag::SkipFunctionBody) == 0)
+ return;
+
+ if (serialCls->typeKind != SerialTypeKind::NodeBase)
+ return;
+ auto cls = serialCls;
+ while (cls)
+ {
+ auto astNodeType = (ASTNodeType)cls->subType;
+ if (astNodeType == ASTNodeType::FunctionDeclBase)
+ {
+ funcDecl = (FunctionDeclBase*)ptr;
+ break;
+ }
+ cls = cls->super;
+ }
+ if (funcDecl)
+ {
+ oldBody = funcDecl->body;
+ // We always need to include body of unsafeForceInlineEarly functions
+ // since they will need to be available at IR lowering time of the
+ // user module for pre-linking inling.
+ if (!isUnsafeForceInlineFunc(funcDecl))
+ {
+ funcDecl->body = nullptr;
+ }
+ }
+
+ }
+ ~SkipFunctionBodyRAII()
+ {
+ if (funcDecl)
+ {
+ funcDecl->body = oldBody;
+ }
+ }
+};
+
SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* ptr)
{
if (serialCls->flags & SerialClassFlag::DontSerialize)
@@ -229,6 +274,16 @@ SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void*
return SerialIndex(0);
}
+ if (serialCls->typeKind == SerialTypeKind::NodeBase &&
+ ReflectClassInfo::isSubClassOf(serialCls->subType, Val::kReflectClassInfo))
+ {
+ return writeValObject((Val*)ptr);
+ }
+
+ // If we are skipping function bodies, set the body field to nullptr, and
+ // restore it after serialization.
+ SkipFunctionBodyRAII clearFunctionBodyRAII(m_flags, serialCls, ptr);
+
// This pointer cannot be in the map
SLANG_ASSERT(m_ptrMap.tryGetValue(ptr) == nullptr);
@@ -279,6 +334,62 @@ SerialIndex SerialWriter::writeObject(const NodeBase* node)
return writeObject(serialClass, (const void*)node);
}
+SerialIndex SerialWriter::writeValObject(const Val* node)
+{
+ typedef SerialInfo::ValEntry ValEntry;
+
+ size_t size = node->getOperandCount() * sizeof(SerialInfo::SerialValOperand);
+ ValEntry* nodeEntry = (ValEntry*)m_arena.allocateAligned(sizeof(ValEntry) + size, SerialInfo::MAX_ALIGNMENT);
+
+ nodeEntry->typeKind = SerialTypeKind::NodeBase;
+ nodeEntry->subType = (SerialSubType)node->astNodeType;
+ nodeEntry->operandCount = (uint32_t)node->getOperandCount();
+ nodeEntry->info = SerialInfo::makeEntryInfo(SerialInfo::MAX_ALIGNMENT);
+
+ // We add before adding fields, so if the fields point to this, the entry will be set
+ auto index = _add(node, nodeEntry);
+
+ ShortList<SerialIndex, 4> serializedOperands;
+
+ for (Index i = 0; i < node->getOperandCount(); i++)
+ {
+ auto operand = node->m_operands[i];
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ serializedOperands.add((SerialIndex)0);
+ break;
+ case ValNodeOperandKind::ValNode:
+ case ValNodeOperandKind::ASTNode:
+ serializedOperands.add(addPointer(operand.values.nodeOperand));
+ break;
+ }
+ }
+
+ SLANG_ASSERT(serializedOperands.getCount() == node->getOperandCount());
+
+ auto serialOperands = (SerialInfo::SerialValOperand*)(nodeEntry + 1);
+ for (Index i = 0; i < node->getOperandCount(); i++)
+ {
+ auto serialOperand = serialOperands + i;
+ auto operand = node->m_operands[i];
+ serialOperand->type = (int)operand.kind;
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ serialOperand->payload = operand.values.intOperand;
+ break;
+ case ValNodeOperandKind::ValNode:
+ serialOperand->payload = (uint64_t)serializedOperands[i];
+ break;
+ case ValNodeOperandKind::ASTNode:
+ serialOperand->payload = (uint64_t)serializedOperands[i];
+ break;
+ }
+ }
+ return index;
+}
+
SerialIndex SerialWriter::writeObject(const RefObject* obj)
{
const SerialRefObject* serialObj = as<const SerialRefObject>(obj);
@@ -633,6 +744,9 @@ size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const
auto serialClass = serialClasses->getSerialClass(typeKind, entry->subType);
+ if (ReflectClassInfo::isSubClassOf(entry->subType, Val::kReflectClassInfo))
+ return sizeof(ValEntry) + static_cast<const ValEntry*>(this)->operandCount * sizeof(SerialValOperand);
+
// Align by the alignment of the entry
size_t alignment = getAlignment(entry->info);
size_t size = sizeof(ObjectEntry) + serialClass->size;
@@ -722,6 +836,49 @@ SerialPointer SerialReader::getPointer(SerialIndex index)
return ptr;
}
+SerialPointer SerialReader::getValPointer(SerialIndex index)
+{
+ if (index == SerialIndex(0))
+ {
+ return SerialPointer();
+ }
+
+ SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount()));
+
+ SerialPointer& ptr = m_objects[Index(index)];
+
+ if (ptr.m_ptr)
+ return ptr;
+
+ const SerialInfo::ValEntry* entry = (SerialInfo::ValEntry*)m_entries[Index(index)];
+ ValNodeDesc desc;
+ desc.type = (ASTNodeType)entry->subType;
+ auto readPtr = (SerialInfo::SerialValOperand*)(entry + 1);
+ for (uint32_t i = 0; i < entry->operandCount; i++)
+ {
+ auto serialOperand = readPtr[i];
+ ValNodeOperand operand;
+ operand.kind = (ValNodeOperandKind)(serialOperand.type);
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ operand.values.intOperand = serialOperand.payload;
+ break;
+ case ValNodeOperandKind::ASTNode:
+ operand.values.nodeOperand = (NodeBase*)getPointer((SerialIndex)serialOperand.payload).m_ptr;
+ break;
+ case ValNodeOperandKind::ValNode:
+ operand.values.nodeOperand = (Val*)getValPointer((SerialIndex)serialOperand.payload).m_ptr;
+ break;
+ }
+ desc.operands.add(operand);
+ }
+ desc.init();
+ ptr.m_kind = SerialTypeKind::NodeBase;
+ ptr.m_ptr = this->m_objectFactory->getOrCreateVal(_Move(desc));
+ return ptr;
+}
+
String SerialReader::getString(SerialIndex index)
{
if (index == SerialIndex(0))
@@ -902,6 +1059,12 @@ SlangResult SerialReader::constructObjects(NamePool* namePool)
case SerialTypeKind::NodeBase:
{
auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry);
+
+ // Don't create object for Vals.
+ if (objectEntry->typeKind == SerialTypeKind::NodeBase &&
+ ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo))
+ break;
+
void* obj = m_objectFactory->create(objectEntry->typeKind, objectEntry->subType);
if (!obj)
{
@@ -912,7 +1075,7 @@ SlangResult SerialReader::constructObjects(NamePool* namePool)
}
case SerialTypeKind::Array:
{
- // Don't need to construct an object, as will be accessed an interpreted by the object that holds it
+ // Don't need to construct an object, as will be accessed and interpreted by the object that holds it
break;
}
}
@@ -944,6 +1107,8 @@ SlangResult SerialReader::deserializeObjects()
{
return SLANG_FAIL;
}
+ if (ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo))
+ continue;
const uint8_t* src = (const uint8_t*)(objectEntry + 1);
uint8_t* dst = (uint8_t*)dstPtr.m_ptr;