summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-builder.cpp
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp281
1 files changed, 142 insertions, 139 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 64a7abd8c..96fb6ac79 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -29,12 +29,6 @@ void SharedASTBuilder::init(Session* session)
// Clear the built in types
memset(m_builtinTypes, 0, sizeof(m_builtinTypes));
- // Create common shared types
- m_errorType = m_astBuilder->create<ErrorType>();
- m_bottomType = m_astBuilder->create<BottomType>();
- m_initializerListType = m_astBuilder->create<InitializerListType>();
- m_overloadedType = m_astBuilder->create<OverloadGroupType>();
-
// We can just iterate over the class pointers.
// NOTE! That this adds the names of the abstract classes too(!)
for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i)
@@ -151,6 +145,31 @@ Type* SharedASTBuilder::getDiffInterfaceType()
return m_diffInterfaceType;
}
+Type* SharedASTBuilder::getErrorType()
+{
+ if (!m_errorType)
+ m_errorType = m_astBuilder->getOrCreate<ErrorType>();
+ return m_errorType;
+}
+Type* SharedASTBuilder::getBottomType()
+{
+ if (!m_bottomType)
+ m_bottomType = m_astBuilder->getOrCreate<BottomType>();
+ return m_bottomType;
+}
+Type* SharedASTBuilder::getInitializerListType()
+{
+ if (!m_initializerListType)
+ m_initializerListType = m_astBuilder->getOrCreate<InitializerListType>();
+ return m_initializerListType;
+}
+Type* SharedASTBuilder::getOverloadedType()
+{
+ if (!m_overloadedType)
+ m_overloadedType = m_astBuilder->getOrCreate<OverloadGroupType>();
+ return m_overloadedType;
+}
+
SharedASTBuilder::~SharedASTBuilder()
{
// Release built in types..
@@ -208,19 +227,28 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name)
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Index& _getGlobalASTEpochId()
+{
+ static thread_local Index epochId = 1;
+ return epochId;
+}
+
ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
m_sharedASTBuilder(sharedASTBuilder),
m_name(name),
m_id(sharedASTBuilder->m_id++),
- m_arena(2048)
+ m_arena(2097152)
{
SLANG_ASSERT(sharedASTBuilder);
+ // Copy Val deduplication map over so we don't create duplicate Vals that are already
+ // existent in the stdlib.
+ m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes;
}
ASTBuilder::ASTBuilder():
m_sharedASTBuilder(nullptr),
m_id(-1),
- m_arena(2048)
+ m_arena(2097152)
{
m_name = "SharedASTBuilder::m_astBuilder";
}
@@ -233,6 +261,25 @@ ASTBuilder::~ASTBuilder()
SLANG_ASSERT(info->m_destructorFunc);
info->m_destructorFunc(node);
}
+ incrementEpoch();
+}
+
+Index ASTBuilder::getEpoch()
+{
+ return _getGlobalASTEpochId();
+}
+
+void ASTBuilder::incrementEpoch()
+{
+ _getGlobalASTEpochId()++;
+}
+
+void ASTBuilder::_verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc)
+{
+ if (!val)
+ return;
+ ValNodeDesc descOut = val->getDesc();
+ SLANG_ASSERT(descOut == expectedDesc);
}
NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
@@ -256,6 +303,13 @@ Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTy
return rsType;
}
+Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName)
+{
+ auto declRef = getBuiltinDeclRef(magicTypeName, genericArgs);
+ auto rsType = DeclRefType::create(this, declRef);
+ return rsType;
+}
+
PtrType* ASTBuilder::getPtrType(Type* valueType)
{
return dynamicCast<PtrType>(getPtrType(valueType, "PtrType"));
@@ -292,64 +346,57 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element
{
if (!elementCount)
elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
-
- auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount);
- if (!result->declRef.getDecl())
+ if (elementCount->getType() != getIntType())
{
- auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType"));
- auto arrayTypeDecl = arrayGenericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, arrayGenericDecl, elementType, elementCount);
- result->declRef = getSpecializedDeclRef<Decl>(arrayTypeDecl, substitutions);
+ // Canonicalize constant elementCount to int.
+ if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount))
+ {
+ elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
+ }
}
- return result;
+ Val* args[] = {elementType, elementCount};
+ return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType"));
}
ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType)
{
- auto result = getOrCreate<ConstantBufferType>(elementType);
- if (!result->declRef.getDecl())
- {
- auto genericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ConstantBuffer"));
- auto typeDecl = genericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, genericDecl, elementType);
- result->declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- }
- return result;
+ return as<ConstantBufferType>(getSpecializedBuiltinType(elementType, "ConstantBufferType"));
+}
+
+ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType)
+{
+ return as<ParameterBlockType>(getSpecializedBuiltinType(elementType, "ParameterBlockType"));
+}
+
+HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType)
+{
+ return as<HLSLStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType"));
+}
+
+SamplerStateType* ASTBuilder::getSamplerStateType()
+{
+ return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType"));
}
VectorExpressionType* ASTBuilder::getVectorType(
Type* elementType,
IntVal* elementCount)
{
- auto result = getOrCreate<VectorExpressionType>(elementType, elementCount);
- if (!result->declRef.getDecl())
+ // Canonicalize constant elementCount to int.
+ if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount))
{
- auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
- auto vectorTypeDecl = vectorGenericDecl->inner;
- auto substitutions = getOrCreateGenericSubstitution(nullptr, vectorGenericDecl, elementType, elementCount);
- result->declRef = getSpecializedDeclRef<Decl>(vectorTypeDecl, substitutions);
+ elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue());
}
- return result;
+ Val* args[] = { elementType, elementCount };
+ return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType"));
}
DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
Witness* primalIsDifferentialWitness)
{
- auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType"));
-
- auto typeDecl = genericDecl->inner;
-
- auto substitutions = getOrCreateGenericSubstitution(
- nullptr,
- genericDecl,
- valueType,
- primalIsDifferentialWitness);
-
- auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- auto rsType = DeclRefType::create(this, declRef);
-
- return as<DifferentialPairType>(rsType);
+ Val* args[] = { valueType, primalIsDifferentialWitness };
+ return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
@@ -377,20 +424,9 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier(
: as<HLSLIndicesModifier>(modifier) ? "IndicesType"
: as<HLSLPrimitivesModifier>(modifier) ? "PrimitivesType"
: (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr);
- auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(declName));
-
- auto typeDecl = genericDecl->inner;
-
- auto substitutions = getOrCreateGenericSubstitution(
- nullptr,
- genericDecl,
- elementType,
- maxElementCount);
- auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions);
- auto rsType = DeclRefType::create(this, declRef);
-
- return as<MeshOutputType>(rsType);
+ Val* args[] = { elementType, maxElementCount };
+ return as<MeshOutputType>(getSpecializedBuiltinType(makeArrayView(args), declName));
}
Type* ASTBuilder::getDifferentiableInterfaceType()
@@ -403,13 +439,8 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(decl))
{
- decl = genericDecl->inner;
- Substitutions* subst = nullptr;
- if (genericArg)
- {
- subst = getOrCreateGenericSubstitution(nullptr, genericDecl, genericArg);
- }
- return getSpecializedDeclRef(decl, subst);
+ auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg));
+ return declRef;
}
else
{
@@ -418,6 +449,21 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va
return makeDeclRef(decl);
}
+DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs)
+{
+ auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
+ if (auto genericDecl = as<GenericDecl>(decl))
+ {
+ auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), genericArgs);
+ return declRef;
+ }
+ else
+ {
+ SLANG_ASSERT(!decl && !genericArgs.getCount());
+ }
+ return makeDeclRef(decl);
+}
+
Type* ASTBuilder::getAndType(Type* left, Type* right)
{
auto type = getOrCreate<AndType>(left, right);
@@ -426,9 +472,7 @@ Type* ASTBuilder::getAndType(Type* left, Type* right)
Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers)
{
- auto type = create<ModifiedType>();
- type->base = base;
- type->modifiers.addRange(modifiers, modifierCount);
+ auto type = getOrCreate<ModifiedType>(base, makeArrayView((Val**)modifiers, modifierCount));
return type;
}
@@ -447,15 +491,16 @@ Val* ASTBuilder::getNoDiffModifierVal()
return getOrCreate<NoDiffModifierVal>();
}
-Type* ASTBuilder::getFuncType(List<Type*> parameters, Type* result)
+FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType)
{
- auto errorType = getOrCreate<BottomType>();
+ if (!errorType)
+ errorType = getOrCreate<BottomType>();
return getOrCreate<FuncType>(parameters, result, errorType);
}
-Type* ASTBuilder::getTupleType(List<Type*>& types)
+TupleType* ASTBuilder::getTupleType(List<Type*>& types)
{
- return getOrCreate<TupleType>(types);
+ return getOrCreate<TupleType>(types.getArrayView());
}
TypeType* ASTBuilder::getTypeType(Type* type)
@@ -466,11 +511,11 @@ TypeType* ASTBuilder::getTypeType(Type* type)
TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(
Type* type)
{
- return getOrCreate<TypeEqualityWitness>(type);
+ return getOrCreate<TypeEqualityWitness>(type, type);
}
-SubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
+DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
Type* subType,
Type* superType,
DeclRef<Decl> const& declRef)
@@ -517,8 +562,8 @@ top:
// Let's call the intermediate type here `x`, we know that the `b <: c`
// witness is based on witnesses that `b <: x` and `x <: c`:
//
- auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->subToMid;
- auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->midToSup;
+ auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->getSubToMid();
+ auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->getMidToSup();
// We can recursively call this operation to produce a witness that
// `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`:
@@ -535,8 +580,8 @@ top:
goto top;
}
- auto aType = aIsSubtypeOfBWitness->sub;
- auto cType = bIsSubtypeOfCWitness->sup;
+ auto aType = aIsSubtypeOfBWitness->getSub();
+ auto cType = bIsSubtypeOfCWitness->getSup();
// If the right-hand side is a conjunction witness for `B <: C`
// of the form `(B <: X)&(B <: Y)`, then we have it that `C = X&Y`
@@ -565,8 +610,8 @@ top:
// the witness `W` that `B <: X&Y&...` as well as the index
// `i` of `C` within the conjunction.
//
- auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->conjunctionWitness;
- auto indexOfCInConjunction = bIsSubtypeViaExtraction->indexInConjunction;
+ auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->getConjunctionWitness();
+ auto indexOfCInConjunction = bIsSubtypeViaExtraction->getIndexInConjunction();
// We lift the extraction to the outside of the composition, by
// forming a witness for `A <: C` that is of the form
@@ -591,24 +636,14 @@ top:
// formal set of rules for the allowed structure of our witnesses to
// guarantee that our simplifications are sufficient.
- TransitiveSubtypeWitness* transitiveWitness = getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
+ TransitiveSubtypeWitness* transitiveWitness = getOrCreate<TransitiveSubtypeWitness>(
aType,
cType,
aIsSubtypeOfBWitness,
bIsSubtypeOfCWitness);
- transitiveWitness->sub = aType;
- transitiveWitness->sup = cType;
- transitiveWitness->subToMid = aIsSubtypeOfBWitness;
- transitiveWitness->midToSup = bIsSubtypeOfCWitness;
-
return transitiveWitness;
}
-ThisTypeSubtypeWitness* ASTBuilder::getThisTypeSubtypeWitness(Type* subType, Type* superType)
-{
- return getOrCreate<ThisTypeSubtypeWitness>(subType, superType);
-}
-
SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
Type* subType,
Type* superType,
@@ -633,16 +668,11 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
//
// * What if the original witness is transitive?
- auto witness = getOrCreateWithDefaultCtor<ExtractFromConjunctionSubtypeWitness>(
+ auto witness = getOrCreate<ExtractFromConjunctionSubtypeWitness>(
subType,
superType,
conjunctionWitness,
indexOfSuperTypeInConjunction);
-
- witness->sub = subType;
- witness->sup = superType;
- witness->conjunctionWitness = conjunctionWitness;
- witness->indexInConjunction = indexOfSuperTypeInConjunction;
return witness;
}
@@ -662,11 +692,11 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness);
if(lExtract && rExtract)
{
- if (lExtract->indexInConjunction == 0
- && rExtract->indexInConjunction == 1)
+ if (lExtract->getIndexInConjunction() == 0
+ && rExtract->getIndexInConjunction() == 1)
{
- auto lInner = lExtract->conjunctionWitness;
- auto rInner = rExtract->conjunctionWitness;
+ auto lInner = lExtract->getConjunctionWitness();
+ auto rInner = rExtract->getConjunctionWitness();
if (lInner == rInner)
{
return lInner;
@@ -685,57 +715,30 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
// witness) deeper, so that we have more chances to expose a
// conjunction witness at higher levels.
- auto witness = getOrCreateWithDefaultCtor<ConjunctionSubtypeWitness>(
+ auto witness = getOrCreate<ConjunctionSubtypeWitness>(
sub,
lAndR,
subIsLWitness,
subIsRWitness);
- witness->componentWitnesses[0] = subIsLWitness;
- witness->componentWitnesses[1] = subIsRWitness;
- witness->sub = sub;
- witness->sup = lAndR;
return witness;
}
-bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
+DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
{
- if (hashCode != that.hashCode) return false;
- if(type != that.type) return false;
- if(operands.getCount() != that.operands.getCount()) return false;
- for(Index i = 0; i < operands.getCount(); ++i)
- {
- // Note: we are comparing the operands directly for identity
- // (pointer equality) rather than doing the `Val`-level
- // equality check.
- //
- // The rationale here is that nodes that will be created
- // via a `NodeDesc` *should* all be going through the
- // deduplication path anyway, as should their operands.
- //
- if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
- }
- return true;
+ return builder->getMemberDeclRef(parent, decl);
}
-void ASTBuilder::NodeDesc::init()
+
+thread_local ASTBuilder* gCurrentASTBuilder = nullptr;
+
+ASTBuilder* getCurrentASTBuilder()
{
- Hasher hasher;
- hasher.hashValue(Int(type));
- for(Index i = 0; i < operands.getCount(); ++i)
- {
- // Note: we are hashing the raw pointer value rather
- // than the content of the value node. This is done
- // to match the semantics implemented for `==` on
- // `NodeDesc`.
- //
- hasher.hashValue(operands[i].values.nodeOperand);
- }
- hashCode = hasher.getResult();
+ return gCurrentASTBuilder;
}
-DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst)
+void setCurrentASTBuilder(ASTBuilder* astBuilder)
{
- return builder->getSpecializedDeclRef(decl, subst);
+ gCurrentASTBuilder = astBuilder;
}
} // namespace Slang