diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-07 15:00:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-07 15:00:38 -0700 |
| commit | 9eb6a84285c1597d723be13924a7ad2991cf717f (patch) | |
| tree | ad4358fb9dcbbd4b561670d02671859a217ad14a /source/slang | |
| parent | 9ef9cc00d98d1775f0ad86efd246ca1605b3b3e4 (diff) | |
Fix `Val` deduplication bug. (#3050)
* Fix `Val` deduplication bug.
* Fix
* Concat stdlib files into a single module.
* Remove unnecessary logic in `resolve`.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-base.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 25 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl-ref.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 68 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast-type-info.h | 39 | ||||
| -rw-r--r-- | source/slang/slang-serialize-container.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-serialize-type-info.h | 28 | ||||
| -rw-r--r-- | source/slang/slang-serialize.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 14 |
11 files changed, 103 insertions, 163 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index d8f4c8c6c..399097b3e 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -46,8 +46,6 @@ class NodeBase /// The actual type is set when constructed on the ASTBuilder. ASTNodeType astNodeType = ASTNodeType(-1); - // Handy when debugging, shouldn't be checked in though! - // virtual ~NodeBase() {} #ifdef _DEBUG SLANG_UNREFLECTED int32_t _debugUID = 0; #endif @@ -197,7 +195,25 @@ struct ValNodeDesc ASTNodeType type; ShortList<ValNodeOperand, 4> operands; - bool operator==(ValNodeDesc const& that) const; + inline bool operator==(ValNodeDesc const& that) const + { + if (hashCode != that.hashCode) return false; + if (type != that.type) return false; + if (operands.getCount() != that.operands.getCount()) return false; + for (Index i = 0; i < operands.getCount(); ++i) + { + // Note: we are comparing the operands directly for identity + // (pointer equality) rather than doing the `Val`-level + // equality check. + // + // The rationale here is that nodes that will be created + // via a `NodeDesc` *should* all be going through the + // deduplication path anyway, as should their operands. + // + if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; + } + return true; + } HashCode getHashCode() const { return hashCode; } void init(); private: @@ -430,6 +446,10 @@ class Val : public NodeBase m_operands.add(ValNodeOperand(v)); } List<ValNodeOperand> m_operands; + + // Private use by stdlib deserialization only. Since we know the Vals serialized into stdlib is already + // unique, we can just use `this` pointer as the `m_resolvedVal` so we don't need to resolve them again. + void _setUnique(); protected: Val* defaultResolveImpl(); private: diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 0d63e1060..e1835c741 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -117,23 +117,24 @@ class ASTBuilder : public RefObject friend class SharedASTBuilder; public: - template<typename NodeCreateFunc> - NodeBase* _getOrCreateImpl(ValNodeDesc const& desc, NodeCreateFunc createFunc) + + Val* _getOrCreateImpl(ValNodeDesc const& desc) { if (auto found = m_cachedNodes.tryGetValue(desc)) return *found; - auto node = createFunc(); + auto node = as<Val>(createByNodeType(desc.type)); + SLANG_ASSERT(node); + for (auto& operand : desc.operands) + node->m_operands.add(operand); + m_cachedNodes.add(desc, node); -#ifdef _DEBUG - _verifyValDescConsistency(dynamicCast<Val>(node), desc); -#endif return node; } /// A cache for AST nodes that are entirely defined by their node type, with /// no need for additional state. - Dictionary<ValNodeDesc, NodeBase*> m_cachedNodes; + Dictionary<ValNodeDesc, Val*> m_cachedNodes; Dictionary<GenericDecl*, List<Val*>> m_cachedGenericDefaultArgs; @@ -198,10 +199,10 @@ public: desc.type = T::kType; addOrAppendToNodeList(desc.operands, args...); desc.init(); - auto result = (T*)_getOrCreateImpl(desc, [&]() - { - return createImpl<T>(args...); - }); + auto result = (T*)_getOrCreateImpl(desc); +#ifdef _DEBUG + _verifyValDescConsistency(dynamicCast<Val>(result), desc); +#endif return result; } @@ -213,7 +214,7 @@ public: ValNodeDesc desc; desc.type = T::kType; desc.init(); - auto result = (T*)_getOrCreateImpl(desc, [this]() { return createImpl<T>(); }); + auto result = (T*)_getOrCreateImpl(desc); #ifdef _DEBUG _verifyValDescConsistency(dynamicCast<Val>(result), desc); #endif diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 4384a6df9..c77cf72ed 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -166,7 +166,7 @@ Val* LookupDeclRef::tryResolve(SubtypeWitness* newWitness, Type* newLookupSource case RequirementWitness::Flavor::val: { - auto satisfyingVal = requirementWitness.getVal(); + auto satisfyingVal = requirementWitness.getVal()->resolve(); return satisfyingVal; } break; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 056577eb0..dd30f8ef6 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -10,27 +10,6 @@ namespace Slang { - -bool ValNodeDesc::operator==(ValNodeDesc const& that) const -{ - if (hashCode != that.hashCode) return false; - if (type != that.type) return false; - if (operands.getCount() != that.operands.getCount()) return false; - for (Index i = 0; i < operands.getCount(); ++i) - { - // Note: we are comparing the operands directly for identity - // (pointer equality) rather than doing the `Val`-level - // equality check. - // - // The rationale here is that nodes that will be created - // via a `NodeDesc` *should* all be going through the - // deduplication path anyway, as should their operands. - // - if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; - } - return true; -} - void ValNodeDesc::init() { Hasher hasher; @@ -77,7 +56,6 @@ Val* Val::resolveImpl() Val* Val::resolve() { auto astBuilder = getCurrentASTBuilder(); - // If we are not in a proper checking context, just return the previously resolved val. if (!astBuilder) return m_resolvedVal? m_resolvedVal : this; @@ -86,37 +64,13 @@ Val* Val::resolve() SLANG_ASSERT(as<Val>(m_resolvedVal)); return m_resolvedVal; } - // Update epoch now to avoid infinite recursion. m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch(); - m_resolvedVal = this; m_resolvedVal = resolveImpl(); - - // Check if we are resolved to an existing Val in the AST cache. - ValNodeDesc newDesc; - newDesc.type = m_resolvedVal->astNodeType; - for (auto operand : m_resolvedVal->m_operands) - { - if (operand.kind == ValNodeOperandKind::ValNode) - { - auto valOperand = as<Val>(operand.values.nodeOperand); - if (valOperand) - { - operand.values.nodeOperand = valOperand->resolve(); - } - } - newDesc.operands.add(operand); - } - newDesc.init(); - - NodeBase* existingNode = nullptr; - if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) - m_resolvedVal = as<Val>(existingNode); - #ifdef _DEBUG if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0) { - //SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking."); + SLANG_ASSERT_FAILURE("should not be modifying stdlib vals outside of stdlib checking."); } #endif return m_resolvedVal; @@ -132,11 +86,18 @@ ValNodeDesc Val::getDesc() return desc; } +void Val::_setUnique() +{ + m_resolvedVal = this; + m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch(); +} + Val* Val::defaultResolveImpl() { // Default resolve implementation is to recursively resolve all operands, and lookup in deduplication cache. ValNodeDesc newDesc; newDesc.type = astNodeType; + bool diff = false; for (auto operand : m_operands) { if (operand.kind == ValNodeOperandKind::ValNode) @@ -144,7 +105,12 @@ Val* Val::defaultResolveImpl() auto valOperand = as<Val>(operand.values.nodeOperand); if (valOperand) { - operand.values.nodeOperand = valOperand->resolve(); + auto newOperand = valOperand->resolve(); + if (newOperand != valOperand) + { + diff = true; + operand.values.nodeOperand = newOperand; + } } } newDesc.operands.add(operand); @@ -152,10 +118,10 @@ Val* Val::defaultResolveImpl() newDesc.init(); auto astBuilder = getCurrentASTBuilder(); - NodeBase* existingNode = nullptr; + Val* existingNode = nullptr; if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode)) - return as<Val>(existingNode); - return this; + return existingNode; + return astBuilder->_getOrCreateImpl(newDesc); } String Val::toString() diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index b6a5d94ef..4e8a00907 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1181,7 +1181,9 @@ namespace Slang if (auto nodiffModifier = modifiedType->findModifier<NoDiffModifierVal>()) { varDecl->type.type = getRemovedModifierType(modifiedType, nodiffModifier); - addModifier(varDecl, m_astBuilder->getOrCreate<NoDiffModifier>()); + auto noDiffModifier = m_astBuilder->create<NoDiffModifier>(); + noDiffModifier->loc = varDecl->loc; + addModifier(varDecl, noDiffModifier); } } diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index 351b6f519..5ccf9ea54 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -43,28 +43,37 @@ struct SerialTypeInfo<SyntaxClass<T>> template <> struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {}; -inline void serializePointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial) +inline void serializeValPointerValue(SerialWriter* writer, Val* ptrValue, SerialIndex* outSerial) { if (ptrValue) ptrValue = ptrValue->resolve(); *(SerialIndex*)outSerial = writer->addPointer(ptrValue); } -inline void deserializePointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr, Val* unusedForResolution) +inline void deserializeValPointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr) { - SLANG_UNUSED(unusedForResolution); - auto val = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<Val>(); *(Val**)outPtr = val; - if (val) +} + +template<typename T> +struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>> +{ + typedef T* NativeType; + typedef SerialIndex SerialType; + enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; + + static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) { - SLANG_ASSERT(as<Val>(val)); - PostSerializationFixUp fixup; - fixup.kind = PostSerializationFixUpKind::ValPtr; - fixup.addressToModify = outPtr; - reader->getFixUps().add(fixup); + auto ptrValue = *(T**)inNative; + serializeValPointerValue(writer, ptrValue, (SerialIndex*)outSerial); } -} + + static void toNative(SerialReader* reader, const void* inSerial, void* outNative) + { + deserializeValPointerValue(reader, (SerialIndex*)inSerial, outNative); + } +}; template <typename T> struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {}; @@ -89,9 +98,9 @@ struct SerialTypeInfo<ValNodeOperand> if (src.kind == ValNodeOperandKind::ConstantValue) dst.val = src.values.intOperand; else if (src.kind == ValNodeOperandKind::ValNode) - serializePointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val); + serializeValPointerValue(writer, (Val*)src.values.nodeOperand, (SerialIndex*)&dst.val); else - serializePointerValue(writer, src.values.nodeOperand, (SerialIndex*)&dst.val); + SerialTypeInfo<NodeBase*>::toSerial(writer, &src.values.nodeOperand, (SerialIndex*)&dst.val); } static void toNative(SerialReader* reader, const void* serial, void* native) { @@ -104,9 +113,9 @@ struct SerialTypeInfo<ValNodeOperand> if (dst.kind == ValNodeOperandKind::ConstantValue) dst.values.intOperand = int64_t(src.val); else if (dst.kind == ValNodeOperandKind::ValNode) - deserializePointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand, (Val*)nullptr); + deserializeValPointerValue(reader, (SerialIndex*)&src.val, (Val**)&dst.values.nodeOperand); else - deserializePointerValue(reader, (SerialIndex*)&src.val, &dst.values.nodeOperand, (NodeBase*)nullptr); + SerialTypeInfo<NodeBase*>::toNative(reader, (SerialIndex*)&src.val, (NodeBase**)&dst.values.nodeOperand); } }; diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index 293535b02..8105d32fb 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -566,44 +566,9 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( } else if (Val* val = dynamicCast<Val>(nodeBase)) { - valUses[val] = List<Val**>(); - } - } - } - // Go through fixup locations and deduplicate Vals. - // This is needed because we currently the same Val can be serialized multiple times - // in different modules. If we have a type defined in Module A and used in Module B, - // then both serialized Module A and Module B will contain a Type Val object that refers to A. - // When we load B, we should resolve those type references to the existing Type val instead. - // This step can be avoided if we can run deduplication while deserializing, which - // requires a different way of handling Val objects. - for (auto fixup : reader.getFixUps()) - { - if (fixup.kind == PostSerializationFixUpKind::ValPtr) - { - auto list = valUses.tryGetValue(*(Val**)fixup.addressToModify); - if (list) - list->add((Val**)fixup.addressToModify); - } - } - SLANG_AST_BUILDER_RAII(astBuilder); - for (auto& valUseList : valUses) - { - auto val = valUseList.key; - auto desc = val->getDesc(); - astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val); - } - for (auto& valUseList : valUses) - { - auto val = valUseList.key; - auto newVal = val->resolve(); - if (val != newVal) - { - astBuilder->m_cachedNodes[val->getDesc()] = newVal; - for (auto use : valUseList.value) - { - if (*use != newVal) - *use = newVal; + val->_setUnique(); + auto desc = val->getDesc(); + astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val); } } } diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h index c4b20c5b9..40129b083 100644 --- a/source/slang/slang-serialize-type-info.h +++ b/source/slang/slang-serialize-type-info.h @@ -158,24 +158,8 @@ class Val; // Pointer -template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type> -void serializePointerValue(SerialWriter* writer, T* ptrValue, SerialIndex* outSerial) -{ - static_assert(!IsBaseOf<Val, T>::Value); - *(SerialIndex*)outSerial = writer->addPointer(ptrValue); -} - -template<typename T, typename sfinae = typename std::enable_if<!IsBaseOf<Val, T>::Value>::type> -void deserializePointerValue(SerialReader* reader, SerialIndex* inSerial, void* outPtr, T* unusedForResolution) -{ - static_assert(!IsBaseOf<Val, T>::Value); - - SLANG_UNUSED(unusedForResolution); - *(T**)outPtr = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>(); -} - -template <typename T> -struct SerialTypeInfo<T*> +template <typename T, typename /*sfinaeType*/ = void> +struct PtrSerialTypeInfo { typedef T* NativeType; typedef SerialIndex SerialType; @@ -184,15 +168,19 @@ struct SerialTypeInfo<T*> static void toSerial(SerialWriter* writer, const void* inNative, void* outSerial) { auto ptrToWrite = *(T**)inNative; - serializePointerValue(writer, ptrToWrite, (SerialIndex*)outSerial); + static_assert(!IsBaseOf<Val, T>::Value); + *(SerialIndex*)outSerial = writer->addPointer(ptrToWrite); } static void toNative(SerialReader* reader, const void* inSerial, void* outNative) { - deserializePointerValue(reader, (SerialIndex*)inSerial, outNative, (T*)nullptr); + *(T**)outNative = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<T>(); } }; +template<typename T> +struct SerialTypeInfo<T*> : public PtrSerialTypeInfo<T> {}; + // RefPtr (pretty much the same as T* - except for native rep) template <typename T> struct SerialTypeInfo<RefPtr<T>> diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index ce7bfa87b..3071dc174 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -216,12 +216,6 @@ enum class PostSerializationFixUpKind ValPtr, }; -struct PostSerializationFixUp -{ - PostSerializationFixUpKind kind; - void* addressToModify; -}; - /* This class is the interface used by toNative implementations to recreate a type. */ class SerialReader : public RefObject { @@ -251,8 +245,6 @@ public: /// Get the entries list const List<const Entry*>& getEntries() const { return m_entries; } - List<PostSerializationFixUp>& getFixUps() { return m_fixUps; } - /// Access the objects list /// NOTE that if a SerialObject holding a RefObject and needs to be kept in scope, add the RefObject* via addScope List<SerialPointer>& getObjects() { return m_objects; } @@ -289,9 +281,7 @@ protected: SerialExtraObjects m_extraObjects; SerialObjectFactory* m_objectFactory; - SerialClasses* m_classes; ///< Information used to deserialize - - List<PostSerializationFixUp> m_fixUps; + SerialClasses* m_classes; ///< Information used to deserialize }; // --------------------------------------------------------------------------- diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index ae44e0c70..f6b902c68 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -437,12 +437,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId; nodeDesc.operands.add(ValNodeOperand(declRef)); nodeDesc.init(); - NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc, [&]() - { - auto resultNode = as<DeclRefType>(classInfo.createInstance(astBuilder)); - resultNode->setOperands(declRef); - return resultNode; - }); + NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc); if (!type) { SLANG_UNEXPECTED("constructor failure"); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 266533874..cb6e88db4 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -311,9 +311,13 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags) } // TODO(JS): Could make this return a SlangResult as opposed to exception - addBuiltinSource(coreLanguageScope, "core", getCoreLibraryCode()); - addBuiltinSource(hlslLanguageScope, "hlsl", getHLSLLibraryCode()); - addBuiltinSource(autodiffLanguageScope, "diff", getAutodiffLibraryCode()); + StringBuilder stdLibSrcBuilder; + stdLibSrcBuilder + << (const char*)getCoreLibraryCode()->getBufferPointer() + << (const char*)getHLSLLibraryCode()->getBufferPointer() + << (const char*)getAutodiffLibraryCode()->getBufferPointer(); + auto stdLibSrcBlob = StringBlob::moveCreate(stdLibSrcBuilder.produceString()); + addBuiltinSource(coreLanguageScope, "core", stdLibSrcBlob); if (compileFlags & slang::CompileStdLibFlag::WriteDocumentation) { @@ -359,6 +363,8 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags) SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes) { + SLANG_PROFILE; + if (m_builtinLinkage->mapNameToLoadedModules.getCount()) { // Already have a StdLib loaded @@ -373,8 +379,6 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes) // Let's try loading serialized modules and adding them SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, coreLanguageScope, "core")); - SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, hlslLanguageScope, "hlsl")); - SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, autodiffLanguageScope, "diff")); return SLANG_OK; } |
