summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-09 09:11:23 -0700
committerGitHub <noreply@github.com>2023-08-09 09:11:23 -0700
commitc4615fe0ae7e1849b23e9a96d1453794b0b40e90 (patch)
tree0d8d4eed0c90df9664420737d60749f391c11ffb /source/slang
parent793a29afc9539f893883b5ad8d88639d63f401e0 (diff)
Clean up and improve Val deduplication performance. (#3069)
* Clean up and improve Val deuplication performance. * Fix. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-base.cpp2
-rw-r--r--source/slang/slang-ast-base.h25
-rw-r--r--source/slang/slang-ast-builder.cpp8
-rw-r--r--source/slang/slang-ast-builder.h81
-rw-r--r--source/slang/slang-ast-modifier.h2
-rw-r--r--source/slang/slang-ast-type.cpp2
-rw-r--r--source/slang/slang-ast-val.cpp26
-rwxr-xr-xsource/slang/slang-compiler.h1
-rw-r--r--source/slang/slang-parser.cpp6
-rw-r--r--source/slang/slang-serialize-container.cpp3
-rw-r--r--source/slang/slang-stdlib.cpp25
-rw-r--r--source/slang/slang-syntax.cpp13
-rw-r--r--source/slang/slang.cpp4
13 files changed, 126 insertions, 72 deletions
diff --git a/source/slang/slang-ast-base.cpp b/source/slang/slang-ast-base.cpp
index 0ad2bb101..5b4a4ea0c 100644
--- a/source/slang/slang-ast-base.cpp
+++ b/source/slang/slang-ast-base.cpp
@@ -25,7 +25,7 @@ DeclRefBase* Decl::getDefaultDeclRef()
auto astBuilder = getCurrentASTBuilder();
if (astBuilder->getEpoch() != m_defaultDeclRefEpoch || !m_defaultDeclRef)
{
- m_defaultDeclRef = astBuilder->getDirectDeclRef(this);
+ m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this);
m_defaultDeclRefEpoch = astBuilder->getEpoch();
}
return m_defaultDeclRef;
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 399097b3e..0eefbab0f 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -140,11 +140,14 @@ struct ValNodeOperand
ValNodeOperand()
{
- values.nodeOperand = nullptr;
+ values.intOperand = 0;
}
explicit ValNodeOperand(NodeBase* node)
{
+ if constexpr(sizeof(values.nodeOperand) < sizeof(values.intOperand))
+ values.intOperand = 0;
+
if (as<Val>(node))
{
values.nodeOperand = (NodeBase*)node;
@@ -158,11 +161,18 @@ struct ValNodeOperand
}
template<typename T>
- explicit ValNodeOperand(DeclRef<T> declRef) { values.nodeOperand = declRef.declRefBase; kind = ValNodeOperandKind::ValNode; }
+ explicit ValNodeOperand(DeclRef<T> declRef)
+ {
+ if constexpr (sizeof(values.nodeOperand) < sizeof(values.intOperand))
+ values.intOperand = 0;
+ values.nodeOperand = declRef.declRefBase; kind = ValNodeOperandKind::ValNode;
+ }
template<typename T>
explicit ValNodeOperand(T* node)
{
+ if constexpr (sizeof(values.nodeOperand) < sizeof(values.intOperand))
+ values.intOperand = 0;
if constexpr (std::is_base_of<Val, T>::value)
{
values.nodeOperand = (NodeBase*)node;
@@ -192,8 +202,11 @@ struct ValNodeOperand
struct ValNodeDesc
{
+private:
+ HashCode hashCode = 0;
+public:
ASTNodeType type;
- ShortList<ValNodeOperand, 4> operands;
+ ShortList<ValNodeOperand, 8> operands;
inline bool operator==(ValNodeDesc const& that) const
{
@@ -210,14 +223,13 @@ struct ValNodeDesc
// via a `NodeDesc` *should* all be going through the
// deduplication path anyway, as should their operands.
//
- if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
+ if (operands[i].values.intOperand != that.operands[i].values.intOperand) return false;
}
return true;
}
HashCode getHashCode() const { return hashCode; }
void init();
-private:
- HashCode hashCode = 0;
+
};
template<int N>
@@ -406,7 +418,6 @@ class Val : public NodeBase
Val* resolveImpl();
Val* resolve();
- ValNodeDesc getDesc();
Val* getOperand(Index index) const
{
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 96fb6ac79..bb4f53433 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -274,14 +274,6 @@ void ASTBuilder::incrementEpoch()
_getGlobalASTEpochId()++;
}
-void ASTBuilder::_verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc)
-{
- if (!val)
- return;
- ValNodeDesc descOut = val->getDesc();
- SLANG_ASSERT(descOut == expectedDesc);
-}
-
NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
{
const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType);
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index e1835c741..5c0e74851 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -112,13 +112,55 @@ protected:
Index m_id = 1;
};
+struct ValKey
+{
+ Val* val;
+ HashCode hashCode;
+ ValKey() = default;
+ ValKey(Val* v)
+ {
+ val = v;
+ Hasher hasher;
+ hasher.hashValue(v->astNodeType);
+ for (auto& operand : v->m_operands)
+ hasher.hashValue(operand.values.intOperand);
+ hashCode = hasher.getResult();
+ }
+ bool operator==(ValKey other) const
+ {
+ if (val == other.val) return true;
+ if (hashCode != other.hashCode) return false;
+ if (val->astNodeType != other.val->astNodeType)
+ return false;
+ if (val->m_operands.getCount() != other.val->m_operands.getCount())
+ return false;
+ for (Index i = 0; i < val->m_operands.getCount(); i++)
+ if (val->m_operands[i].values.intOperand != other.val->m_operands[i].values.intOperand)
+ return false;
+ return true;
+ }
+ bool operator==(const ValNodeDesc& desc) const
+ {
+ if (hashCode != desc.getHashCode()) return false;
+ if (val->astNodeType != desc.type)
+ return false;
+ if (val->m_operands.getCount() != desc.operands.getCount())
+ return false;
+ for (Index i = 0; i < val->m_operands.getCount(); i++)
+ if (val->m_operands[i].values.intOperand != desc.operands[i].values.intOperand)
+ return false;
+ return true;
+ }
+ HashCode getHashCode() const { return hashCode; }
+};
+
class ASTBuilder : public RefObject
{
friend class SharedASTBuilder;
public:
- Val* _getOrCreateImpl(ValNodeDesc const& desc)
+ Val* _getOrCreateImpl(ValNodeDesc&& desc)
{
if (auto found = m_cachedNodes.tryGetValue(desc))
return *found;
@@ -127,14 +169,14 @@ public:
SLANG_ASSERT(node);
for (auto& operand : desc.operands)
node->m_operands.add(operand);
-
- m_cachedNodes.add(desc, node);
- return node;
+ auto result = node;
+ m_cachedNodes.add(ValKey(node), _Move(node));
+ return result;
}
/// A cache for AST nodes that are entirely defined by their node type, with
/// no need for additional state.
- Dictionary<ValNodeDesc, Val*> m_cachedNodes;
+ Dictionary<ValKey, Val*> m_cachedNodes;
Dictionary<GenericDecl*, List<Val*>> m_cachedGenericDefaultArgs;
@@ -189,8 +231,6 @@ public:
MemoryArena& getArena() { return m_arena; }
- void _verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc);
-
template<typename T, typename ... TArgs>
SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args)
{
@@ -199,10 +239,7 @@ public:
desc.type = T::kType;
addOrAppendToNodeList(desc.operands, args...);
desc.init();
- auto result = (T*)_getOrCreateImpl(desc);
-#ifdef _DEBUG
- _verifyValDescConsistency(dynamicCast<Val>(result), desc);
-#endif
+ auto result = (T*)_getOrCreateImpl(_Move(desc));
return result;
}
@@ -214,10 +251,7 @@ public:
ValNodeDesc desc;
desc.type = T::kType;
desc.init();
- auto result = (T*)_getOrCreateImpl(desc);
-#ifdef _DEBUG
- _verifyValDescConsistency(dynamicCast<Val>(result), desc);
-#endif
+ auto result = (T*)_getOrCreateImpl(_Move(desc));
return result;
}
@@ -237,13 +271,9 @@ public:
}
template<typename T>
- DeclRef<T> getDirectDeclRef(T* decl)
+ DeclRef<T> getDirectDeclRef(T* decl, typename std::enable_if_t<std::is_base_of_v<Decl, T>>* = nullptr)
{
- if (!decl)
- return DeclRef<T>();
-
- auto result = DeclRef<T>(getOrCreate<DirectDeclRef>(decl));
- return result;
+ return DeclRef<T>(decl);
}
template<typename T>
@@ -285,7 +315,7 @@ public:
}
else if (auto directDeclRef = as<DirectDeclRef>(parent.declRefBase))
{
- return DeclRef<T>(getOrCreate<DirectDeclRef>(memberDecl));
+ return makeDeclRef(memberDecl);
}
#if _DEBUG
@@ -331,11 +361,6 @@ public:
LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup)
{
- ValNodeDesc desc;
- desc.type = LookupDeclRef::kType;
- desc.operands.add(ValNodeOperand(subtypeWitness));
- desc.operands.add(ValNodeOperand(declToLookup));
- desc.init();
auto result = getOrCreate<LookupDeclRef>(declToLookup, base, subtypeWitness);
return result;
}
@@ -487,7 +512,7 @@ public:
Index getId() { return m_id; }
/// Ctor
- ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);
+ ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);
/// Dtor
~ASTBuilder();
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 8e7cc9193..01cf3a24a 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -392,6 +392,8 @@ class MagicTypeModifier : public Modifier
{
SLANG_AST_CLASS(MagicTypeModifier)
+ ASTNodeType magicNodeType = ASTNodeType(-1);
+
/// Modifier has a name so call this magicModifier to disambiguate
String magicName;
uint32_t tag = uint32_t(0);
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 13133a7f8..7ea5e8ed1 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -692,7 +692,7 @@ Val* AndType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet su
(*ioDiff)++;
- auto substType = getCurrentASTBuilder()->getAndType(substLeft, substRight);
+ auto substType = astBuilder->getAndType(substLeft, substRight);
return substType;
}
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index dd30f8ef6..a0ea60625 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -21,7 +21,7 @@ void ValNodeDesc::init()
// to match the semantics implemented for `==` on
// `NodeDesc`.
//
- hasher.hashValue(operands[i].values.nodeOperand);
+ hasher.hashValue(operands[i].values.intOperand);
}
hashCode = hasher.getResult();
}
@@ -59,13 +59,13 @@ Val* Val::resolve()
// If we are not in a proper checking context, just return the previously resolved val.
if (!astBuilder)
return m_resolvedVal? m_resolvedVal : this;
- if (m_resolvedVal && m_resolvedValEpoch == getCurrentASTBuilder()->getEpoch())
+ if (m_resolvedVal && m_resolvedValEpoch == astBuilder->getEpoch())
{
SLANG_ASSERT(as<Val>(m_resolvedVal));
return m_resolvedVal;
}
// Update epoch now to avoid infinite recursion.
- m_resolvedValEpoch = getCurrentASTBuilder()->getEpoch();
+ m_resolvedValEpoch = astBuilder->getEpoch();
m_resolvedVal = resolveImpl();
#ifdef _DEBUG
if (m_resolvedVal->_debugUID > 0 && this->_debugUID < 0)
@@ -76,16 +76,6 @@ Val* Val::resolve()
return m_resolvedVal;
}
-ValNodeDesc Val::getDesc()
-{
- ValNodeDesc desc;
- desc.type = astNodeType;
- for (auto operand : m_operands)
- desc.operands.add(operand);
- desc.init();
- return desc;
-}
-
void Val::_setUnique()
{
m_resolvedVal = this;
@@ -115,13 +105,13 @@ Val* Val::defaultResolveImpl()
}
newDesc.operands.add(operand);
}
+
+ if (!diff)
+ return this;
+
newDesc.init();
auto astBuilder = getCurrentASTBuilder();
-
- Val* existingNode = nullptr;
- if (astBuilder->m_cachedNodes.tryGetValue(newDesc, existingNode))
- return existingNode;
- return astBuilder->_getOrCreateImpl(newDesc);
+ return astBuilder->_getOrCreateImpl(_Move(newDesc));
}
String Val::toString()
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 96456478d..a6231d959 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -3026,6 +3026,7 @@ namespace Slang
/// This AST Builder should only be used for creating AST nodes that are global across requests
/// not doing so could lead to memory being consumed but not used.
ASTBuilder* getGlobalASTBuilder() { return globalAstBuilder; }
+ void finalizeSharedASTBuilder();
RefPtr<ASTBuilder> globalAstBuilder;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 4448a96e1..1836fd550 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6659,6 +6659,12 @@ namespace Slang
{
modifier->tag = uint32_t(stringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent()));
}
+ auto classInfo = parser->astBuilder->findClassInfo(getName(parser, modifier->magicName));
+ if (classInfo)
+ {
+ modifier->magicNodeType = ASTNodeType(classInfo->m_classId);
+ }
+ // TODO: print diagnostic if the magic type name doesn't correspond to an actual ASTNodeType.
parser->ReadToken(TokenType::RParent);
return modifier;
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index 8105d32fb..263abf465 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -567,8 +567,7 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
else if (Val* val = dynamicCast<Val>(nodeBase))
{
val->_setUnique();
- auto desc = val->getDesc();
- astBuilder->m_cachedNodes.tryGetValueOrAdd(desc, val);
+ astBuilder->m_cachedNodes.tryGetValueOrAdd(ValKey(val), val);
}
}
}
diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp
index f3f8f325e..65d5cf758 100644
--- a/source/slang/slang-stdlib.cpp
+++ b/source/slang/slang-stdlib.cpp
@@ -112,6 +112,31 @@ namespace Slang
};
+ void Session::finalizeSharedASTBuilder()
+ {
+ // Force creation of all builtin types so we can make sure
+ // they are created by the builtin AST builder instead of
+ // some user linkage's ast builder. This avoid the problem
+ // of storing a reference to these global types that are
+ // owned by a user linkage that gets deleted with the linkage.
+ //
+ globalAstBuilder->getNoneType();
+ globalAstBuilder->getNullPtrType();
+ globalAstBuilder->getBottomType();
+ globalAstBuilder->getErrorType();
+ globalAstBuilder->getInitializerListType();
+ globalAstBuilder->getOverloadedType();
+ globalAstBuilder->getStringType();
+ globalAstBuilder->getEnumTypeType();
+ globalAstBuilder->getDiffInterfaceType();
+ globalAstBuilder->getSharedASTBuilder()->getDynamicType();
+ globalAstBuilder->getSharedASTBuilder()->getDiffInterfaceType();
+ globalAstBuilder->getSharedASTBuilder()->getNativeStringType();
+ for (auto& baseType : kBaseTypes)
+ globalAstBuilder->getBuiltinType(baseType.tag);
+ }
+
+
// Given two base types, we need to be able to compute the cost of converting between them.
ConversionCost getBaseTypeConversionCost(
BaseTypeConversionInfo const& toInfo,
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index f6b902c68..b7e9af43a 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -423,21 +423,20 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>())
{
+ if (magicMod->magicNodeType == ASTNodeType(-1))
+ {
+ SLANG_UNEXPECTED("unhandled type");
+ }
// Always create builtin types in global AST builder.
if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder)
return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef);
declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef);
- auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice());
- if (!classInfo.classInfo)
- {
- SLANG_UNEXPECTED("unhandled type");
- }
ValNodeDesc nodeDesc = {};
- nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId;
+ nodeDesc.type = magicMod->magicNodeType;
nodeDesc.operands.add(ValNodeOperand(declRef));
nodeDesc.init();
- NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc);
+ NodeBase* type = astBuilder->_getOrCreateImpl(_Move(nodeDesc));
if (!type)
{
SLANG_UNEXPECTED("constructor failure");
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index cb6e88db4..1edf62a38 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -358,6 +358,8 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags)
}
}
+ finalizeSharedASTBuilder();
+
return SLANG_OK;
}
@@ -379,6 +381,8 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
// Let's try loading serialized modules and adding them
SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, coreLanguageScope, "core"));
+
+ finalizeSharedASTBuilder();
return SLANG_OK;
}