summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-07 15:00:38 -0700
committerGitHub <noreply@github.com>2023-08-07 15:00:38 -0700
commit9eb6a84285c1597d723be13924a7ad2991cf717f (patch)
treead4358fb9dcbbd4b561670d02671859a217ad14a /source/slang
parent9ef9cc00d98d1775f0ad86efd246ca1605b3b3e4 (diff)
Fix `Val` deduplication bug. (#3050)
* Fix `Val` deduplication bug. * Fix * Concat stdlib files into a single module. * Remove unnecessary logic in `resolve`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-base.h26
-rw-r--r--source/slang/slang-ast-builder.h25
-rw-r--r--source/slang/slang-ast-decl-ref.cpp2
-rw-r--r--source/slang/slang-ast-val.cpp68
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-serialize-ast-type-info.h39
-rw-r--r--source/slang/slang-serialize-container.cpp41
-rw-r--r--source/slang/slang-serialize-type-info.h28
-rw-r--r--source/slang/slang-serialize.h12
-rw-r--r--source/slang/slang-syntax.cpp7
-rw-r--r--source/slang/slang.cpp14
11 files changed, 103 insertions, 163 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index d8f4c8c6c..399097b3e 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -46,8 +46,6 @@ class NodeBase
/// The actual type is set when constructed on the ASTBuilder.
ASTNodeType astNodeType = ASTNodeType(-1);
- // Handy when debugging, shouldn't be checked in though!
- // virtual ~NodeBase() {}
#ifdef _DEBUG
SLANG_UNREFLECTED int32_t _debugUID = 0;
#endif
@@ -197,7 +195,25 @@ struct ValNodeDesc
ASTNodeType type;
ShortList<ValNodeOperand, 4> operands;
- bool operator==(ValNodeDesc const& that) const;
+ inline bool operator==(ValNodeDesc const& that) const
+ {
+ if (hashCode != that.hashCode) return false;
+ if (type != that.type) return false;
+ if (operands.getCount() != that.operands.getCount()) return false;
+ for (Index i = 0; i < operands.getCount(); ++i)
+ {
+ // Note: we are comparing the operands directly for identity
+ // (pointer equality) rather than doing the `Val`-level
+ // equality check.
+ //
+ // The rationale here is that nodes that will be created
+ // via a `NodeDesc` *should* all be going through the
+ // deduplication path anyway, as should their operands.
+ //
+ if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
+ }
+ return true;
+ }
HashCode getHashCode() const { return hashCode; }
void init();
private:
@@ -430,6 +446,10 @@ class Val : public NodeBase
m_operands.add(ValNodeOperand(v));
}
List<ValNodeOperand> m_operands;
+
+ // Private use by stdlib deserialization only. Since we know the Vals serialized into stdlib is already
+ // unique, we can just use `this` pointer as the `m_resolvedVal` so we don't need to resolve them again.
+ void _setUnique();
protected:
Val* defaultResolveImpl();
private:
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 0d63e1060..e1835c741 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -117,23 +117,24 @@ class ASTBuilder : public RefObject
friend class SharedASTBuilder;
public:
- template<typename NodeCreateFunc>
- NodeBase* _getOrCreateImpl(ValNodeDesc const& desc, NodeCreateFunc createFunc)
+
+ Val* _getOrCreateImpl(ValNodeDesc const& desc)
{
if (auto found = m_cachedNodes.tryGetValue(desc))
return *found;
- auto node = createFunc();
+ auto node = as<Val>(createByNodeType(desc.type));
+ SLANG_ASSERT(node);
+ for (auto& operand : desc.operands)
+ node->m_operands.add(operand);
+
m_cachedNodes.add(desc, node);
-#ifdef _DEBUG
- _verifyValDescConsistency(dynamicCast<Val>(node), desc);
-#endif
return node;
}
/// A cache for AST nodes that are entirely defined by their node type, with
/// no need for additional state.
- Dictionary<ValNodeDesc, NodeBase*> m_cachedNodes;
+ Dictionary<ValNodeDesc, Val*> m_cachedNodes;
Dictionary<GenericDecl*, List<Val*>> m_cachedGenericDefaultArgs;
@@ -198,10 +199,10 @@ public:
desc.type = T::kType;
addOrAppendToNodeList(desc.operands, args...);
desc.init();
- auto result = (T*)_getOrCreateImpl(desc, [&]()
- {
- return createImpl<T>(args...);
- });
+ auto result = (T*)_getOrCreateImpl(desc);
+#ifdef _DEBUG
+ _verifyValDescConsistency(dynamicCast<Val>(result), desc);
+#endif
return result;
}
@@ -213,7 +214,7 @@ public:
ValNodeDesc desc;
desc.type = T::kType;
desc.init();
- auto result = (T*)_getOrCreateImpl(desc, [this]() { return createImpl<T>(); });
+ auto result = (T*)_getOrCreateImpl(desc);
#ifdef _DEBUG
_verifyValDescConsistency(dynamicCast<Val>(result), desc);
#endif
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
index 4384a6df9..c77cf72ed 100644
--- a/source/slang/slang-ast-decl-ref.cpp
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -166,7 +166,7 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource
case RequirementWitness::Flavor::val:
{
- auto satisfyingVal = requirementWitness.getVal();
+ auto satisfyingVal = requirementWitness.getVal()->resolve();
return satisfyingVal;
}
break;
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 056577eb0..dd30f8ef6 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -10,27 +10,6 @@
namespace Slang {
-
-bool ValNodeDesc::operator==(ValNodeDesc const& that) const
-{
- if (hashCode != that.hashCode) return false;
- if (type != that.type) return false;
- if (operands.getCount() != that.operands.getCount()) return false;
- for (Index i = 0; i < operands.getCount(); ++i)
- {
- // Note: we are comparing the operands directly for identity
- // (pointer equality) rather than doing the `Val`-level
- // equality check.
- //
- // The rationale here is that nodes that will be created
- // via a `NodeDesc` *should* all be going through the
- // deduplication path anyway, as should their operands.
- //
- if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
- }
- return true;
-}
-
void ValNodeDesc::init()
{
Hasher hasher;
@@ -77,7 +56,6 @@ Val* Val::resolveImpl()
Val* Val::resolve()
{
auto astBuilder = getCurrentASTBuilder();
-
// If we are not in a proper checking context, just return the previously resolved val.
if (!astBuilder)
return m_resolvedVal? m_resolvedVal : this;
@@ -86,37 +64,13 @@ Val* Val::resolve()
SLANG_ASSERT(as<Val>(m_resolvedVal));
return m_resolvedVal;
}
-
// Update epoch now to avoid infinite recursion.
m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch();
- m_resolvedVal = this;
m_resolvedVal = resolveImpl();
-
- // Check if we are resolved to an existing Val in the AST cache.
- ValNodeDesc newDesc;
- newDesc.type = m_resolvedVal->astNodeType;
- for (auto operand : m_resolvedVal->m_operands)
- {
- if (operand.kind == ValNodeOperandKind::ValNode)
- {
- auto valOperand = as<Val>(operand.values.nodeOperand);
- if (valOperand)
- {
- operand.values.nodeOperand = valOperand->resolve();
- }
- }
- newDesc.operands.add(operand);
- }
- newDesc.init();
-
- NodeBase* existingNode = nullptr;
- if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
- m_resolvedVal = as<Val>(existingNode);
-
#ifdef _DEBUG
if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0)
{
- //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking.");
+ SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking.");
}
#endif
return m_resolvedVal;
@@ -132,11 +86,18 @@ ValNodeDesc Val::getDesc()
return desc;
}
+void Val::_setUnique()
+{
+ m_resolvedVal = this;
+ m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch();
+}
+
Val* Val::defaultResolveImpl()
{
// Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache.
ValNodeDesc newDesc;
newDesc.type = astNodeType;
+ bool diff = false;
for (auto operand : m_operands)
{
if (operand.kind == ValNodeOperandKind::ValNode)
@@ -144,7 +105,12 @@ Val* Val::defaultResolveImpl()
auto valOperand = as<Val>(operand.values.nodeOperand);
if (valOperand)
{
- operand.values.nodeOperand = valOperand->resolve();
+ auto newOperand = valOperand->resolve();
+ if (newOperand != valOperand)
+ {
+ diff = true;
+ operand.values.nodeOperand = newOperand;
+ }
}
}
newDesc.operands.add(operand);
@@ -152,10 +118,10 @@ Val* Val::defaultResolveImpl()
newDesc.init();
auto astBuilder = getCurrentASTBuilder();
- NodeBase* existingNode = nullptr;
+ Val* existingNode = nullptr;
if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
- return as<Val>(existingNode);
- return this;
+ return existingNode;
+ return astBuilder->_getOrCreateImpl(newDesc);
}
String Val::toString()
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index b6a5d94ef..4e8a00907 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1181,7 +1181,9 @@ namespace Slang
if (auto nodiffModifier = modifiedType->findModifier<NoDiffModifierVal>())
{
varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier);
- addModifier(varDecl, m_astBuilder->getOrCreate<NoDiffModifier>());
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->loc = varDecl->loc;
+ addModifier(varDecl, noDiffModifier);
}
}
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index 351b6f519..5ccf9ea54 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -43,28 +43,37 @@ struct SerialTypeInfo<SyntaxClass<T>>
template <>
struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {};
-inline void serializePointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial)
+inline void serializeValPointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial)
{
if (ptrValue)
ptrValue = ptrValue->resolve();
*(SerialIndex*)outSerial = writer->addPointer(ptrValue);
}
-inline void deserializePointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr, Val* unusedForResolution)
+inline void deserializeValPointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr)
{
- SLANG_UNUSED(unusedForResolution);
-
auto val = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<Val>();
*(Val**)outPtr = val;
- if (val)
+}
+
+template<typename T>
+struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>>
+{
+ typedef T* NativeType;
+ typedef SerialIndex SerialType;
+ enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) };
+
+ static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial)
{
- SLANG_ASSERT(as<Val>(val));
- PostSerializationFixUp fixup;
- fixup.kind = PostSerializationFixUpKind::ValPtr;
- fixup.addressToModify = outPtr;
- reader->getFixUps().add(fixup);
+ auto ptrValue = *(T**)inNative;
+ serializeValPointerValue(writer, ptrValue, (SerialIndex*)outSerial);
}
-}
+
+ static void toNative(SerialReader* reader, const void* inSerial, void* outNative)
+ {
+ deserializeValPointerValue(reader, (SerialIndex*)inSerial, outNative);
+ }
+};
template <typename T>
struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {};
@@ -89,9 +98,9 @@ struct SerialTypeInfo<ValNodeOperand>
if (src.kind == ValNodeOperandKind::ConstantValue)
dst.val = src.values.intOperand;
else if (src.kind == ValNodeOperandKind::ValNode)
- serializePointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val);
+ serializeValPointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val);
else
- serializePointerValue(writer, src.values.nodeOperand, (SerialIndex*)&dst.val);
+ SerialTypeInfo<NodeBase*>::toSerial(writer, &src.values.nodeOperand, (SerialIndex*)&dst.val);
}
static void toNative(SerialReader* reader, const void* serial, void* native)
{
@@ -104,9 +113,9 @@ struct SerialTypeInfo<ValNodeOperand>
if (dst.kind == ValNodeOperandKind::ConstantValue)
dst.values.intOperand = int64_t(src.val);
else if (dst.kind == ValNodeOperandKind::ValNode)
- deserializePointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand, (Val*)nullptr);
+ deserializeValPointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand);
else
- deserializePointerValue(reader, (SerialIndex*)&src.val, &dst.values.nodeOperand, (NodeBase*)nullptr);
+ SerialTypeInfo<NodeBase*>::toNative(reader, (SerialIndex*)&src.val, (NodeBase**)&dst.values.nodeOperand);
}
};
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index 293535b02..8105d32fb 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -566,44 +566,9 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
}
else if (Val* val = dynamicCast<Val>(nodeBase))
{
- valUses[val] = List<Val**>();
- }
- }
- }
- // Go through fixup locations and deduplicate Vals.
- // This is needed because we currently the same Val can be serialized multiple times
- // in different modules. If we have a type defined in Module A and used in Module B,
- // then both serialized Module A and Module B will contain a Type Val object that refers to A.
- // When we load B, we should resolve those type references to the existing Type val instead.
- // This step can be avoided if we can run deduplication while deserializing, which
- // requires a different way of handling Val objects.
- for (auto fixup : reader.getFixUps())
- {
- if (fixup.kind == PostSerializationFixUpKind::ValPtr)
- {
- auto list = valUses.tryGetValue(*(Val**)fixup.addressToModify);
- if (list)
- list->add((Val**)fixup.addressToModify);
- }
- }
- SLANG_AST_BUILDER_RAII(astBuilder);
- for (auto& valUseList : valUses)
- {
- auto val = valUseList.key;
- auto desc = val->getDesc();
- astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val);
- }
- for (auto& valUseList : valUses)
- {
- auto val = valUseList.key;
- auto newVal = val->resolve();
- if (val != newVal)
- {
- astBuilder->m_cachedNodes[val->getDesc()] = newVal;
- for (auto use : valUseList.value)
- {
- if (*use != newVal)
- *use = newVal;
+ val->_setUnique();
+ auto desc = val->getDesc();
+ astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val);
}
}
}
diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h
index c4b20c5b9..40129b083 100644
--- a/source/slang/slang-serialize-type-info.h
+++ b/source/slang/slang-serialize-type-info.h
@@ -158,24 +158,8 @@ class Val;
// Pointer
-template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type>
-void serializePointerValue(SerialWriter* writer, T* ptrValue, SerialIndex* outSerial)
-{
- static_assert(!IsBaseOf<Val, T>::Value);
- *(SerialIndex*)outSerial = writer->addPointer(ptrValue);
-}
-
-template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type>
-void deserializePointerValue(SerialReader* reader, SerialIndex* inSerial, void* outPtr, T* unusedForResolution)
-{
- static_assert(!IsBaseOf<Val, T>::Value);
-
- SLANG_UNUSED(unusedForResolution);
- *(T**)outPtr = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>();
-}
-
-template <typename T>
-struct SerialTypeInfo<T*>
+template <typename T, typename /*sfinaeType*/ = void>
+struct PtrSerialTypeInfo
{
typedef T* NativeType;
typedef SerialIndex SerialType;
@@ -184,15 +168,19 @@ struct SerialTypeInfo<T*>
static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial)
{
auto ptrToWrite = *(T**)inNative;
- serializePointerValue(writer, ptrToWrite, (SerialIndex*)outSerial);
+ static_assert(!IsBaseOf<Val, T>::Value);
+ *(SerialIndex*)outSerial = writer->addPointer(ptrToWrite);
}
static void toNative(SerialReader* reader, const void* inSerial, void* outNative)
{
- deserializePointerValue(reader, (SerialIndex*)inSerial, outNative, (T*)nullptr);
+ *(T**)outNative = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>();
}
};
+template<typename T>
+struct SerialTypeInfo<T*> : public PtrSerialTypeInfo<T> {};
+
// RefPtr (pretty much the same as T* - except for native rep)
template <typename T>
struct SerialTypeInfo<RefPtr<T>>
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index ce7bfa87b..3071dc174 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -216,12 +216,6 @@ enum class PostSerializationFixUpKind
ValPtr,
};
-struct PostSerializationFixUp
-{
- PostSerializationFixUpKind kind;
- void* addressToModify;
-};
-
/* This class is the interface used by toNative implementations to recreate a type. */
class SerialReader : public RefObject
{
@@ -251,8 +245,6 @@ public:
/// Get the entries list
const List<const Entry*>& getEntries() const { return m_entries; }
- List<PostSerializationFixUp>& getFixUps() { return m_fixUps; }
-
/// Access the objects list
/// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope
List<SerialPointer>& getObjects() { return m_objects; }
@@ -289,9 +281,7 @@ protected:
SerialExtraObjects m_extraObjects;
SerialObjectFactory* m_objectFactory;
- SerialClasses* m_classes; ///< Information used to deserialize
-
- List<PostSerializationFixUp> m_fixUps;
+ SerialClasses* m_classes; ///< Information used to deserialize
};
// ---------------------------------------------------------------------------
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index ae44e0c70..f6b902c68 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -437,12 +437,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId;
nodeDesc.operands.add(ValNodeOperand(declRef));
nodeDesc.init();
- NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc, [&]()
- {
- auto resultNode = as<DeclRefType>(classInfo.createInstance(astBuilder));
- resultNode->setOperands(declRef);
- return resultNode;
- });
+ NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc);
if (!type)
{
SLANG_UNEXPECTED("constructor failure");
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 266533874..cb6e88db4 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -311,9 +311,13 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags)
}
// TODO(JS): Could make this return a SlangResult as opposed to exception
- addBuiltinSource(coreLanguageScope, "core", getCoreLibraryCode());
- addBuiltinSource(hlslLanguageScope, "hlsl", getHLSLLibraryCode());
- addBuiltinSource(autodiffLanguageScope, "diff", getAutodiffLibraryCode());
+ StringBuilder stdLibSrcBuilder;
+ stdLibSrcBuilder
+ << (const char*)getCoreLibraryCode()->getBufferPointer()
+ << (const char*)getHLSLLibraryCode()->getBufferPointer()
+ << (const char*)getAutodiffLibraryCode()->getBufferPointer();
+ auto stdLibSrcBlob = StringBlob::moveCreate(stdLibSrcBuilder.produceString());
+ addBuiltinSource(coreLanguageScope, "core", stdLibSrcBlob);
if (compileFlags & slang::CompileStdLibFlag::WriteDocumentation)
{
@@ -359,6 +363,8 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags)
SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
{
+ SLANG_PROFILE;
+
if (m_builtinLinkage->mapNameToLoadedModules.getCount())
{
// Already have a StdLib loaded
@@ -373,8 +379,6 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
// Let's try loading serialized modules and adding them
SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, coreLanguageScope, "core"));
- SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, hlslLanguageScope, "hlsl"));
- SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, autodiffLanguageScope, "diff"));
return SLANG_OK;
}