diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-ast-builder.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 281 |
1 files changed, 142 insertions, 139 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 64a7abd8c..96fb6ac79 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -29,12 +29,6 @@ void SharedASTBuilder::init(Session* session) // Clear the built in types memset(m_builtinTypes, 0, sizeof(m_builtinTypes)); - // Create common shared types - m_errorType = m_astBuilder->create<ErrorType>(); - m_bottomType = m_astBuilder->create<BottomType>(); - m_initializerListType = m_astBuilder->create<InitializerListType>(); - m_overloadedType = m_astBuilder->create<OverloadGroupType>(); - // We can just iterate over the class pointers. // NOTE! That this adds the names of the abstract classes too(!) for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) @@ -151,6 +145,31 @@ Type* SharedASTBuilder::getDiffInterfaceType() return m_diffInterfaceType; } +Type* SharedASTBuilder::getErrorType() +{ + if (!m_errorType) + m_errorType = m_astBuilder->getOrCreate<ErrorType>(); + return m_errorType; +} +Type* SharedASTBuilder::getBottomType() +{ + if (!m_bottomType) + m_bottomType = m_astBuilder->getOrCreate<BottomType>(); + return m_bottomType; +} +Type* SharedASTBuilder::getInitializerListType() +{ + if (!m_initializerListType) + m_initializerListType = m_astBuilder->getOrCreate<InitializerListType>(); + return m_initializerListType; +} +Type* SharedASTBuilder::getOverloadedType() +{ + if (!m_overloadedType) + m_overloadedType = m_astBuilder->getOrCreate<OverloadGroupType>(); + return m_overloadedType; +} + SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. @@ -208,19 +227,28 @@ Decl* SharedASTBuilder::tryFindMagicDecl(const String& name) // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Index& _getGlobalASTEpochId() +{ + static thread_local Index epochId = 1; + return epochId; +} + ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): m_sharedASTBuilder(sharedASTBuilder), m_name(name), m_id(sharedASTBuilder->m_id++), - m_arena(2048) + m_arena(2097152) { SLANG_ASSERT(sharedASTBuilder); + // Copy Val deduplication map over so we don't create duplicate Vals that are already + // existent in the stdlib. + m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes; } ASTBuilder::ASTBuilder(): m_sharedASTBuilder(nullptr), m_id(-1), - m_arena(2048) + m_arena(2097152) { m_name = "SharedASTBuilder::m_astBuilder"; } @@ -233,6 +261,25 @@ ASTBuilder::~ASTBuilder() SLANG_ASSERT(info->m_destructorFunc); info->m_destructorFunc(node); } + incrementEpoch(); +} + +Index ASTBuilder::getEpoch() +{ + return _getGlobalASTEpochId(); +} + +void ASTBuilder::incrementEpoch() +{ + _getGlobalASTEpochId()++; +} + +void ASTBuilder::_verifyValDescConsistency(Val* val, const ValNodeDesc& expectedDesc) +{ + if (!val) + return; + ValNodeDesc descOut = val->getDesc(); + SLANG_ASSERT(descOut == expectedDesc); } NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) @@ -256,6 +303,13 @@ Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTy return rsType; } +Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName) +{ + auto declRef = getBuiltinDeclRef(magicTypeName, genericArgs); + auto rsType = DeclRefType::create(this, declRef); + return rsType; +} + PtrType* ASTBuilder::getPtrType(Type* valueType) { return dynamicCast<PtrType>(getPtrType(valueType, "PtrType")); @@ -292,64 +346,57 @@ ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* element { if (!elementCount) elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); - - auto result = getOrCreate<ArrayExpressionType>(elementType, elementCount); - if (!result->declRef.getDecl()) + if (elementCount->getType() != getIntType()) { - auto arrayGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ArrayType")); - auto arrayTypeDecl = arrayGenericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, arrayGenericDecl, elementType, elementCount); - result->declRef = getSpecializedDeclRef<Decl>(arrayTypeDecl, substitutions); + // Canonicalize constant elementCount to int. + if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount)) + { + elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); + } } - return result; + Val* args[] = {elementType, elementCount}; + return as<ArrayExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "ArrayExpressionType")); } ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType) { - auto result = getOrCreate<ConstantBufferType>(elementType); - if (!result->declRef.getDecl()) - { - auto genericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("ConstantBuffer")); - auto typeDecl = genericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, genericDecl, elementType); - result->declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - } - return result; + return as<ConstantBufferType>(getSpecializedBuiltinType(elementType, "ConstantBufferType")); +} + +ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType) +{ + return as<ParameterBlockType>(getSpecializedBuiltinType(elementType, "ParameterBlockType")); +} + +HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType) +{ + return as<HLSLStructuredBufferType>(getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType")); +} + +SamplerStateType* ASTBuilder::getSamplerStateType() +{ + return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "HLSLStructuredBufferType")); } VectorExpressionType* ASTBuilder::getVectorType( Type* elementType, IntVal* elementCount) { - auto result = getOrCreate<VectorExpressionType>(elementType, elementCount); - if (!result->declRef.getDecl()) + // Canonicalize constant elementCount to int. + if (auto elementCountConstantInt = as<ConstantIntVal>(elementCount)) { - auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); - auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = getOrCreateGenericSubstitution(nullptr, vectorGenericDecl, elementType, elementCount); - result->declRef = getSpecializedDeclRef<Decl>(vectorTypeDecl, substitutions); + elementCount = getIntVal(getIntType(), elementCountConstantInt->getValue()); } - return result; + Val* args[] = { elementType, elementCount }; + return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType")); } DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, Witness* primalIsDifferentialWitness) { - auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); - - auto typeDecl = genericDecl->inner; - - auto substitutions = getOrCreateGenericSubstitution( - nullptr, - genericDecl, - valueType, - primalIsDifferentialWitness); - - auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - auto rsType = DeclRefType::create(this, declRef); - - return as<DifferentialPairType>(rsType); + Val* args[] = { valueType, primalIsDifferentialWitness }; + return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl() @@ -377,20 +424,9 @@ MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( : as<HLSLIndicesModifier>(modifier) ? "IndicesType" : as<HLSLPrimitivesModifier>(modifier) ? "PrimitivesType" : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr); - auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(declName)); - - auto typeDecl = genericDecl->inner; - - auto substitutions = getOrCreateGenericSubstitution( - nullptr, - genericDecl, - elementType, - maxElementCount); - auto declRef = getSpecializedDeclRef<Decl>(typeDecl, substitutions); - auto rsType = DeclRefType::create(this, declRef); - - return as<MeshOutputType>(rsType); + Val* args[] = { elementType, maxElementCount }; + return as<MeshOutputType>(getSpecializedBuiltinType(makeArrayView(args), declName)); } Type* ASTBuilder::getDifferentiableInterfaceType() @@ -403,13 +439,8 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as<GenericDecl>(decl)) { - decl = genericDecl->inner; - Substitutions* subst = nullptr; - if (genericArg) - { - subst = getOrCreateGenericSubstitution(nullptr, genericDecl, genericArg); - } - return getSpecializedDeclRef(decl, subst); + auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg)); + return declRef; } else { @@ -418,6 +449,21 @@ DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Va return makeDeclRef(decl); } +DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs) +{ + auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); + if (auto genericDecl = as<GenericDecl>(decl)) + { + auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), genericArgs); + return declRef; + } + else + { + SLANG_ASSERT(!decl && !genericArgs.getCount()); + } + return makeDeclRef(decl); +} + Type* ASTBuilder::getAndType(Type* left, Type* right) { auto type = getOrCreate<AndType>(left, right); @@ -426,9 +472,7 @@ Type* ASTBuilder::getAndType(Type* left, Type* right) Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers) { - auto type = create<ModifiedType>(); - type->base = base; - type->modifiers.addRange(modifiers, modifierCount); + auto type = getOrCreate<ModifiedType>(base, makeArrayView((Val**)modifiers, modifierCount)); return type; } @@ -447,15 +491,16 @@ Val* ASTBuilder::getNoDiffModifierVal() return getOrCreate<NoDiffModifierVal>(); } -Type* ASTBuilder::getFuncType(List<Type*> parameters, Type* result) +FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType) { - auto errorType = getOrCreate<BottomType>(); + if (!errorType) + errorType = getOrCreate<BottomType>(); return getOrCreate<FuncType>(parameters, result, errorType); } -Type* ASTBuilder::getTupleType(List<Type*>& types) +TupleType* ASTBuilder::getTupleType(List<Type*>& types) { - return getOrCreate<TupleType>(types); + return getOrCreate<TupleType>(types.getArrayView()); } TypeType* ASTBuilder::getTypeType(Type* type) @@ -466,11 +511,11 @@ TypeType* ASTBuilder::getTypeType(Type* type) TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness( Type* type) { - return getOrCreate<TypeEqualityWitness>(type); + return getOrCreate<TypeEqualityWitness>(type, type); } -SubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness( +DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness( Type* subType, Type* superType, DeclRef<Decl> const& declRef) @@ -517,8 +562,8 @@ top: // Let's call the intermediate type here `x`, we know that the `b <: c` // witness is based on witnesses that `b <: x` and `x <: c`: // - auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->subToMid; - auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->midToSup; + auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->getSubToMid(); + auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->getMidToSup(); // We can recursively call this operation to produce a witness that // `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`: @@ -535,8 +580,8 @@ top: goto top; } - auto aType = aIsSubtypeOfBWitness->sub; - auto cType = bIsSubtypeOfCWitness->sup; + auto aType = aIsSubtypeOfBWitness->getSub(); + auto cType = bIsSubtypeOfCWitness->getSup(); // If the right-hand side is a conjunction witness for `B <: C` // of the form `(B <: X)&(B <: Y)`, then we have it that `C = X&Y` @@ -565,8 +610,8 @@ top: // the witness `W` that `B <: X&Y&...` as well as the index // `i` of `C` within the conjunction. // - auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->conjunctionWitness; - auto indexOfCInConjunction = bIsSubtypeViaExtraction->indexInConjunction; + auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->getConjunctionWitness(); + auto indexOfCInConjunction = bIsSubtypeViaExtraction->getIndexInConjunction(); // We lift the extraction to the outside of the composition, by // forming a witness for `A <: C` that is of the form @@ -591,24 +636,14 @@ top: // formal set of rules for the allowed structure of our witnesses to // guarantee that our simplifications are sufficient. - TransitiveSubtypeWitness* transitiveWitness = getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>( + TransitiveSubtypeWitness* transitiveWitness = getOrCreate<TransitiveSubtypeWitness>( aType, cType, aIsSubtypeOfBWitness, bIsSubtypeOfCWitness); - transitiveWitness->sub = aType; - transitiveWitness->sup = cType; - transitiveWitness->subToMid = aIsSubtypeOfBWitness; - transitiveWitness->midToSup = bIsSubtypeOfCWitness; - return transitiveWitness; } -ThisTypeSubtypeWitness* ASTBuilder::getThisTypeSubtypeWitness(Type* subType, Type* superType) -{ - return getOrCreate<ThisTypeSubtypeWitness>(subType, superType); -} - SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( Type* subType, Type* superType, @@ -633,16 +668,11 @@ SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness( // // * What if the original witness is transitive? - auto witness = getOrCreateWithDefaultCtor<ExtractFromConjunctionSubtypeWitness>( + auto witness = getOrCreate<ExtractFromConjunctionSubtypeWitness>( subType, superType, conjunctionWitness, indexOfSuperTypeInConjunction); - - witness->sub = subType; - witness->sup = superType; - witness->conjunctionWitness = conjunctionWitness; - witness->indexInConjunction = indexOfSuperTypeInConjunction; return witness; } @@ -662,11 +692,11 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness); if(lExtract && rExtract) { - if (lExtract->indexInConjunction == 0 - && rExtract->indexInConjunction == 1) + if (lExtract->getIndexInConjunction() == 0 + && rExtract->getIndexInConjunction() == 1) { - auto lInner = lExtract->conjunctionWitness; - auto rInner = rExtract->conjunctionWitness; + auto lInner = lExtract->getConjunctionWitness(); + auto rInner = rExtract->getConjunctionWitness(); if (lInner == rInner) { return lInner; @@ -685,57 +715,30 @@ SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness( // witness) deeper, so that we have more chances to expose a // conjunction witness at higher levels. - auto witness = getOrCreateWithDefaultCtor<ConjunctionSubtypeWitness>( + auto witness = getOrCreate<ConjunctionSubtypeWitness>( sub, lAndR, subIsLWitness, subIsRWitness); - witness->componentWitnesses[0] = subIsLWitness; - witness->componentWitnesses[1] = subIsRWitness; - witness->sub = sub; - witness->sup = lAndR; return witness; } -bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const +DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl) { - if (hashCode != that.hashCode) return false; - if(type != that.type) return false; - if(operands.getCount() != that.operands.getCount()) return false; - for(Index i = 0; i < operands.getCount(); ++i) - { - // Note: we are comparing the operands directly for identity - // (pointer equality) rather than doing the `Val`-level - // equality check. - // - // The rationale here is that nodes that will be created - // via a `NodeDesc` *should* all be going through the - // deduplication path anyway, as should their operands. - // - if (operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; - } - return true; + return builder->getMemberDeclRef(parent, decl); } -void ASTBuilder::NodeDesc::init() + +thread_local ASTBuilder* gCurrentASTBuilder = nullptr; + +ASTBuilder* getCurrentASTBuilder() { - Hasher hasher; - hasher.hashValue(Int(type)); - for(Index i = 0; i < operands.getCount(); ++i) - { - // Note: we are hashing the raw pointer value rather - // than the content of the value node. This is done - // to match the semantics implemented for `==` on - // `NodeDesc`. - // - hasher.hashValue(operands[i].values.nodeOperand); - } - hashCode = hasher.getResult(); + return gCurrentASTBuilder; } -DeclRef<Decl> _getSpecializedDeclRef(ASTBuilder* builder, Decl* decl, Substitutions* subst) +void setCurrentASTBuilder(ASTBuilder* astBuilder) { - return builder->getSpecializedDeclRef(decl, subst); + gCurrentASTBuilder = astBuilder; } } // namespace Slang |
