diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-syntax.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-syntax.cpp')
| -rw-r--r-- | source/slang/slang-syntax.cpp | 922 |
1 files changed, 108 insertions, 814 deletions
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 227e468d6..ae44e0c70 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -284,14 +284,14 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { if(auto declaredSubtypeWitness = as<DeclaredSubtypeWitness>(subtypeWitness)) { - if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as<InheritanceDecl>()) + if(auto inheritanceDeclRef = declaredSubtypeWitness->getDeclRef().as<InheritanceDecl>()) { // A conformance that was declared as part of an inheritance clause // will have built up a dictionary of the satisfying declarations // for each of its requirements. RequirementWitness requirementWitness; auto witnessTable = inheritanceDeclRef.getDecl()->witnessTable; - if(witnessTable && witnessTable->requirementDictionary.tryGetValue(requirementKey, requirementWitness)) + if(witnessTable && witnessTable->getRequirementDictionary().tryGetValue(requirementKey, requirementWitness)) { // The `inheritanceDeclRef` has substitutions applied to it that // *aren't* present in the `requirementWitness`, because it was @@ -338,7 +338,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // So, in order to get the *right* end result, we need to apply // the substitutions from the inheritance decl-ref to the witness. // - requirementWitness = requirementWitness.specialize(astBuilder, inheritanceDeclRef.getSubst()); + requirementWitness = requirementWitness.specialize(astBuilder, SubstitutionSet(inheritanceDeclRef)); return requirementWitness; } @@ -346,17 +346,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness)) { - if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->midToSup)) + if (auto declaredSubtypeWitnessMidToSup = as<DeclaredSubtypeWitness>(transitiveTypeWitness->getMidToSup())) { - auto midKey = declaredSubtypeWitnessMidToSup->declRef; - auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->subToMid), midKey.getDecl()); + auto midKey = declaredSubtypeWitnessMidToSup->getDeclRef(); + auto midWitness = tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(transitiveTypeWitness->getSubToMid()), midKey.getDecl()); if (midWitness.getFlavor() == RequirementWitness::Flavor::witnessTable) { auto table = midWitness.getWitnessTable(); RequirementWitness result; - if (table->requirementDictionary.tryGetValue(requirementKey, result)) + if (table->getRequirementDictionary().tryGetValue(requirementKey, result)) { - result = result.specialize(astBuilder, midKey.getSubst()); + result = result.specialize(astBuilder, SubstitutionSet(midKey)); } return result; } @@ -364,15 +364,32 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } else if (auto extractFromConjunctionTypeWitness = as<ExtractFromConjunctionSubtypeWitness>(subtypeWitness)) { - if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->conjunctionWitness)) + if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->getConjunctionWitness())) { auto componentWitness = as<SubtypeWitness>( conjunctionTypeWitness->getComponentWitness( - extractFromConjunctionTypeWitness->indexInConjunction)); + extractFromConjunctionTypeWitness->getIndexInConjunction())); return tryLookUpRequirementWitness(astBuilder, componentWitness, requirementKey); } } + + // If we are looking for `ThisType`, just return subtype. + if (as<ThisTypeDecl>(requirementKey)) + { + RequirementWitness result; + result.m_flavor = RequirementWitness::Flavor::val; + result.m_val = subtypeWitness->getSub(); + return result; + } + // If we are looking for `ThisTypeConstraint`, just return the witness itself. + if (as<ThisTypeConstraintDecl>(requirementKey)) + { + RequirementWitness result; + result.m_flavor = RequirementWitness::Flavor::val; + result.m_val = subtypeWitness; + return result; + } // TODO: should handle the transitive case here too return RequirementWitness(); @@ -384,125 +401,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt void WitnessTable::add(Decl* decl, RequirementWitness const& witness) { - SLANG_ASSERT(!requirementDictionary.containsKey(decl)); - - requirementDictionary.add(decl, witness); - } - - // - - static Type* ExtractGenericArgType(Val* val) - { - auto type = as<Type>(val); - SLANG_RELEASE_ASSERT(type); - return type; - } - - static IntVal* ExtractGenericArgInteger(Val* val) - { - auto intVal = as<IntVal>(val); - SLANG_RELEASE_ASSERT(intVal); - return intVal; - } - - DeclRef<Decl> createDefaultSubstitutionsIfNeeded( - ASTBuilder* astBuilder, - SemanticsVisitor* semantics, - DeclRef<Decl> declRef) - { - // It is possible that `declRef` refers to a generic type, - // but does not specify arguments for its generic parameters. - // (E.g., this happens when referring to a generic type from - // within its own member functions). To handle this case, - // we will construct a default specialization at the use - // site if needed. - // - // This same logic should also apply to declarations nested - // more than one level inside of a generic (e.g., a `typdef` - // inside of a generic `struct`). - // - // Similarly, it needs to work for multiple levels of - // nested generics. - // - - // First, we collect all the generic parents. - ShortList<GenericDecl*> genericParents; - Decl* dd = declRef.getDecl(); - for (;;) - { - Decl* childDecl = dd; - Decl* parentDecl = dd->parentDecl; - if (!parentDecl) - break; - - dd = parentDecl; - - if (auto genericParentDecl = as<GenericDecl>(parentDecl)) - { - // Don't specialize any parameters of a generic. - if (childDecl != genericParentDecl->inner) - break; - genericParents.add(genericParentDecl); - } - } - - - Substitutions* outerSubst = nullptr; - for (Index i = genericParents.getCount()-1; i>=0; i--) - { - Decl* childDecl = genericParents[i]->inner; - Decl* parentDecl = genericParents[i]; - - if(auto genericParentDecl = as<GenericDecl>(parentDecl)) - { - // Don't specialize any parameters of a generic. - if(childDecl != genericParentDecl->inner) - break; - - // We have a generic ancestor, but do we have an substitutions for it? - GenericSubstitution* foundSubst = nullptr; - for(auto s = declRef.getSubst(); s; s = s->getOuter()) - { - auto genSubst = as<GenericSubstitution>(s); - if(!genSubst) - continue; - - if(genSubst->getGenericDecl() != genericParentDecl) - continue; - - // Okay, we found a matching substitution, - // so we just grab the args from the matching subst instead. - foundSubst = genSubst; - if (foundSubst->getOuter() != outerSubst) - { - foundSubst = astBuilder->getOrCreateGenericSubstitution( - outerSubst, foundSubst->getGenericDecl(), foundSubst->getArgs()); - } - - break; - } - - if(!foundSubst) - { - Substitutions* newSubst = createDefaultSubstitutionsForGeneric( - astBuilder, - semantics, - genericParentDecl, - outerSubst); - outerSubst = newSubst; - } - else - { - outerSubst = foundSubst; - } - } - } - - if(!outerSubst) - return declRef; - - int diff = 0; - return declRef.substituteImpl(astBuilder, outerSubst, &diff); + m_requirements.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness)); + m_requirementDictionary.add(decl, witness); } // TODO: need to figure out how to unify this with the logic @@ -511,245 +411,73 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<Decl> declRef) { - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); - if (auto builtinMod = declRef.getDecl()->findModifier<BuiltinTypeModifier>()) { - auto type = astBuilder->getOrCreate<BasicExpressionType>(builtinMod->tag); - type->declRef = declRef; + // Always create builtin types in global AST builder. + if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder) + return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef); + + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + auto type = astBuilder->getOrCreate<BasicExpressionType>(declRef.declRefBase); return type; } else if (auto magicMod = declRef.getDecl()->findModifier<MagicTypeModifier>()) { - GenericSubstitution* subst = nullptr; - for(auto s = declRef.getSubst(); s; s = s->getOuter()) - { - if(auto genericSubst = as<GenericSubstitution>(s)) - { - subst = genericSubst; - break; - } - } + // Always create builtin types in global AST builder. + if (astBuilder->getSharedASTBuilder()->getInnerASTBuilder() != astBuilder) + return DeclRefType::create(astBuilder->getSharedASTBuilder()->getInnerASTBuilder(), declRef); - if (magicMod->magicName == "SamplerState") - { - auto type = astBuilder->getOrCreate<SamplerStateType>(SamplerStateFlavor(magicMod->tag)); - type->declRef = declRef; - return type; - } - else if (magicMod->magicName == "Vector") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); - auto vecType = astBuilder->getOrCreate<VectorExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); - vecType->declRef = declRef; - vecType->elementType = ExtractGenericArgType(subst->getArgs()[0]); - vecType->elementCount = ExtractGenericArgInteger(subst->getArgs()[1]); - return vecType; - } - else if (magicMod->magicName == "ArrayType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 2); - auto vecType = astBuilder->getOrCreate<ArrayExpressionType>(ExtractGenericArgType(subst->getArgs()[0]), ExtractGenericArgInteger(subst->getArgs()[1])); - vecType->declRef = declRef; - return vecType; - } - else if (magicMod->magicName == "Matrix") + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice()); + if (!classInfo.classInfo) { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 3); - auto matType = astBuilder->getOrCreate<MatrixExpressionType>( - ExtractGenericArgType(subst->getArgs()[0]), - ExtractGenericArgInteger(subst->getArgs()[1]), - ExtractGenericArgInteger(subst->getArgs()[2])); - matType->declRef = declRef; - return matType; + SLANG_UNEXPECTED("unhandled type"); } - else if (magicMod->magicName == "TensorViewType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); - auto vecType = astBuilder->getOrCreate<TensorViewType>(ExtractGenericArgType(subst->getArgs()[0])); - vecType->declRef = declRef; - return vecType; - } - else if (magicMod->magicName == "Texture") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureTag = TextureFlavor(magicMod->tag); - Val* sampleCount = nullptr; - if (textureTag.isMultisample()) + ValNodeDesc nodeDesc = {}; + nodeDesc.type = (ASTNodeType)classInfo.classInfo->m_classId; + nodeDesc.operands.add(ValNodeOperand(declRef)); + nodeDesc.init(); + NodeBase* type = astBuilder->_getOrCreateImpl(nodeDesc, [&]() { - if (subst->getArgs().getCount() >= 2) - sampleCount = ExtractGenericArgInteger(subst->getArgs().getLast()); - } - auto textureType = astBuilder->getOrCreate<TextureType>( - textureTag, - ExtractGenericArgType(subst->getArgs()[0]), - sampleCount); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "TextureSampler") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureType = astBuilder->getOrCreate<TextureSamplerType>( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->getArgs()[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "GLSLImageType") - { - SLANG_ASSERT(subst && subst->getArgs().getCount() >= 1); - auto textureType = astBuilder->getOrCreate<GLSLImageType>( - TextureFlavor(magicMod->tag), - ExtractGenericArgType(subst->getArgs()[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->magicName == "FeedbackType") + auto resultNode = as<DeclRefType>(classInfo.createInstance(astBuilder)); + resultNode->setOperands(declRef); + return resultNode; + }); + if (!type) { - SLANG_ASSERT(subst == nullptr); - auto type = astBuilder->getOrCreateWithDefaultCtor<FeedbackType>(magicMod->tag); - type->declRef = declRef; - type->kind = FeedbackType::Kind(magicMod->tag); - return type; + SLANG_UNEXPECTED("constructor failure"); } - // TODO: eventually everything should follow this pattern, - // and we can drive the dispatch with a table instead - // of this ridiculously slow `if` cascade. - - #define CASE(n, T) \ - else if (magicMod->magicName == #n) \ - { \ - auto type = astBuilder->getOrCreateWithDefaultCtor<T>( \ - declRef.getDecl(), declRef.getSubst()); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLInputPatchType, HLSLInputPatchType) - CASE(HLSLOutputPatchType, HLSLOutputPatchType) - - #undef CASE - - #define CASE(n, T) \ - else if (magicMod->magicName == #n) \ - { \ - SLANG_ASSERT(subst && subst->getArgs().getCount() == 1); \ - auto type = \ - astBuilder->getOrCreateWithDefaultCtor<T>(ExtractGenericArgType(subst->getArgs()[0])); \ - type->elementType = ExtractGenericArgType(subst->getArgs()[0]); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(ConstantBuffer, ConstantBufferType) - CASE(TextureBuffer, TextureBufferType) - CASE(ParameterBlockType, ParameterBlockType) - CASE(GLSLInputParameterGroupType, GLSLInputParameterGroupType) - CASE(GLSLOutputParameterGroupType, GLSLOutputParameterGroupType) - CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) - - CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) - CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) - CASE(HLSLRasterizerOrderedStructuredBufferType, HLSLRasterizerOrderedStructuredBufferType) - CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) - CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) - - CASE(HLSLPointStreamType, HLSLPointStreamType) - CASE(HLSLLineStreamType, HLSLLineStreamType) - CASE(HLSLTriangleStreamType, HLSLTriangleStreamType) - - #undef CASE - - // "magic" builtin types which have no generic parameters - #define CASE(n,T) \ - else if(magicMod->magicName == #n) { \ - auto type = astBuilder->getOrCreate<T>(); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) - CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) - CASE(HLSLRasterizerOrderedByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) - CASE(UntypedBufferResourceType, UntypedBufferResourceType) - - CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) - - #undef CASE - - else + auto declRefType = dynamicCast<DeclRefType>(type); + if (!declRefType) { - auto classInfo = astBuilder->findSyntaxClass(magicMod->magicName.getUnownedSlice()); - if (!classInfo.classInfo) - { - SLANG_UNEXPECTED("unhandled type"); - } - - NodeBase* type = classInfo.createInstance(astBuilder); - if (!type) - { - SLANG_UNEXPECTED("constructor failure"); - } - - auto declRefType = dynamicCast<DeclRefType>(type); - if (!declRefType) - { - SLANG_UNEXPECTED("expected a declaration reference type"); - } - declRefType->declRef = declRef; - return declRefType; + SLANG_UNEXPECTED("expected a declaration reference type"); } + return declRefType; + } + else if (as<ThisTypeDecl>(declRef.getDecl()) && as<DirectDeclRef>(declRef.declRefBase)) + { + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + + return astBuilder->getOrCreate<ThisType>(declRef.declRefBase); } else { + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef); + return astBuilder->getOrCreate<DeclRefType>(declRef.declRefBase); } } // - GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst) + Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst) { - for(Substitutions* s = subst; s; s = s->getOuter()) - { - if(auto genericSubst = as<GenericSubstitution>(s)) - return genericSubst; - } - return nullptr; - } - - - // DeclRefBase - - Type* DeclRefBase::substitute(ASTBuilder* astBuilder, Type* type) const - { - // Note that type can be nullptr, and so this function can return nullptr (although only correctly when no substitutions) - - // No substitutions? Easy. - if (!substitutions) - return type; - - SLANG_ASSERT(type); - - // Otherwise we need to recurse on the type structure - // and apply substitutions where it makes sense - return Slang::as<Type>(type->substitute(astBuilder, substitutions)); - } - - DeclRefBase* DeclRefBase::substitute(ASTBuilder* astBuilder, DeclRefBase* declRef) const - { - if(!substitutions) - return declRef; - - int diff = 0; - return declRef->substituteImpl(astBuilder, substitutions, &diff); - } - - SubstExpr<Expr> DeclRefBase::substitute(ASTBuilder* /* astBuilder*/, Expr* expr) const - { - return SubstExpr<Expr>(expr, substitutions); + if (!subst.declRef) + return Val::OperandView<Val>(); + if (auto genApp = subst.findGenericAppDeclRef()) + return genApp->getArgs(); + return Val::OperandView<Val>(); } SubstExpr<Expr> substituteExpr(SubstitutionSet const& substs, Expr* expr) @@ -764,7 +492,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt int diff = 0; auto declRefBase = declRef.substituteImpl(astBuilder, substs, &diff); - return astBuilder->getSpecializedDeclRef<Decl>(declRefBase.getDecl(), declRefBase.getSubst()); + return declRefBase; } Type* substituteType(SubstitutionSet const& substs, ASTBuilder* astBuilder, Type* type) @@ -790,332 +518,13 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - Substitutions* specializeSubstitutionsShallow( - ASTBuilder* astBuilder, - Substitutions* substToSpecialize, - Substitutions* substsToApply, - Substitutions* restSubst, - int* ioDiff) - { - SLANG_ASSERT(substToSpecialize); - return substToSpecialize->applySubstitutionsShallow(astBuilder, substsToApply, restSubst, ioDiff); - } - - // Construct new substitutions to apply to a declaration, - // based on a provided substitution set to be applied - Substitutions* specializeSubstitutions( - ASTBuilder* astBuilder, - Decl* declToSpecialize, - Substitutions* substsToSpecialize, - Substitutions* substsToApply, - int* ioDiff) - { - // No declaration? Then nothing to specialize. - if(!declToSpecialize) - return nullptr; - - // No (remaining) substitutions to apply? Then we are done. - if(!substsToApply) - return substsToSpecialize; - - // Walk the hierarchy of the declaration to determine what specializations might apply. - // We assume that the `substsToSpecialize` must be aligned with the ancestor - // hierarchy of `declToSpecialize` such that if, e.g., the `declToSpecialize` is - // nested directly in a generic, then `substToSpecialize` will either start with - // the corresponding `GenericSubstitution` or there will be *no* generic substitutions - // corresponding to that decl. - for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->parentDecl) - { - if(auto ancestorGenericDecl = as<GenericDecl>(ancestorDecl)) - { - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specGenericSubst = as<GenericSubstitution>(substsToSpecialize)) - { - if(specGenericSubst->getGenericDecl() == ancestorGenericDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorGenericDecl->parentDecl, - specGenericSubst->getOuter(), - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - astBuilder, - specGenericSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // If the declaration is not already specialized - // for the given generic, then see if we are trying - // to *apply* such specializations to it. - // - // TODO: The way we handle things right now with - // "default" specializations, this case shouldn't - // actually come up. - // - for(auto s = substsToApply; s; s = s->getOuter()) - { - auto appGenericSubst = as<GenericSubstitution>(s); - if(!appGenericSubst) - continue; - - if(appGenericSubst->getGenericDecl() != ancestorGenericDecl) - continue; - - // The substitutions we are applying are trying - // to specialize this generic, but we don't already - // have a generic substitution in place. - // We will need to create one. - - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorGenericDecl->parentDecl, - substsToSpecialize, - substsToApply, - &diff); - - GenericSubstitution* firstSubst = astBuilder->getOrCreateGenericSubstitution( - restSubst, ancestorGenericDecl, appGenericSubst->getArgs()); - - (*ioDiff)++; - return firstSubst; - } - } - else if(auto ancestorInterfaceDecl = as<InterfaceDecl>(ancestorDecl)) - { - // The task is basically the same as for the generic case: - // We want to see if there is any existing substitution that - // applies to this declaration, and use that if possible. - - // The declaration is nested inside a generic. - // Does it already have a specialization for that generic? - if(auto specThisTypeSubst = as<ThisTypeSubstitution>(substsToSpecialize)) - { - if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) - { - // Yes. We have an existing specialization, so we will - // keep one matching it in place. - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorInterfaceDecl->parentDecl, - specThisTypeSubst->getOuter(), - substsToApply, - &diff); - - auto firstSubst = specializeSubstitutionsShallow( - astBuilder, - specThisTypeSubst, - substsToApply, - restSubst, - &diff); - - *ioDiff += diff; - return firstSubst; - } - } - - // Otherwise, check if we are trying to apply - // a this-type substitution to the given interface - // - // Note: We want to skip the ThisTypeSubstitution that specializes - // declToSpecialize itself (when declToSpecialize is an interface - // decl and the subst specializes it), and only pull the - // ThisTypeSubstitution when the decl is referencing a child of - // the interface decl being specialized. This is because - // by default an interface declref type is a "free" existential - // type that shouldn't be specialized by someone else, unless - // there is an "implicit" ThisType reference preceeding a child - // reference. - if (declToSpecialize != ancestorInterfaceDecl) - { - for (auto s = substsToApply; s; s = s->getOuter()) - { - auto appThisTypeSubst = as<ThisTypeSubstitution>(s); - if (!appThisTypeSubst) - continue; - - if (appThisTypeSubst->interfaceDecl != ancestorInterfaceDecl) - continue; - - int diff = 0; - auto restSubst = specializeSubstitutions( - astBuilder, - ancestorInterfaceDecl->parentDecl, - substsToSpecialize, - substsToApply, - &diff); - - ThisTypeSubstitution* firstSubst = astBuilder->getOrCreateThisTypeSubstitution( - ancestorInterfaceDecl, appThisTypeSubst->witness, restSubst); - - (*ioDiff)++; - return firstSubst; - } - } - } - } - - // If we reach here then we've walked the full hierarchy up from - // `declToSpecialize` and either didn't run into an generic/interface - // declarations, or we didn't find any attempt to specialize them - // in either substitution. - // - // As an invariant, there should *not* be any generic or this-type - // substitutions in `substToSpecialize`, because otherwise they - // would be specializations that don't actually apply to the given - // declaration. - // - // Note: this does *not* mean that `substsToApply` doesn't have - // any generic or this-type substitutions; it just means that none - // of them were applicable. - // - return nullptr; - } - - DeclRefBase* DeclRefBase::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet substSet, int* ioDiff) const - { - // Nothing to do when we have no declaration. - if(!decl) - return const_cast<DeclRefBase*>(this); - - // Apply the given substitutions to any specializations - // that have already been applied to this declaration. - int diff = 0; - - auto substSubst = specializeSubstitutions( - astBuilder, - decl, - substitutions, - substSet.substitutions, - &diff); - - if (!diff) - return const_cast<DeclRefBase*>(this); - - *ioDiff += diff; - - DeclRefBase* substDeclRef = astBuilder->getSpecializedDeclRef(decl, substSubst); - - // TODO: The old code here used to try to translate a decl-ref - // to an associated type in a decl-ref for the concrete type - // in a particular implementation. - // - // I have only kept that logic in `DeclRefType::SubstituteImpl`, - // but it may turn out it is needed here too. - - return substDeclRef; - } - - bool DeclRefBase::_equalsValOverride(Val* val) - { - if (auto otherDeclRef = as<DeclRefBase>(val)) - return equals(otherDeclRef); - return false; - } - - // Check if this is an equivalent declaration reference to another - bool DeclRefBase::equals(DeclRefBase* declRef) const - { - if (!declRef) - return false; - if (decl != declRef->decl) - return false; - if (!SubstitutionSet(substitutions).equals(declRef->substitutions)) - return false; - - return true; - } - - // Convenience accessors for common properties of declarations - Name* DeclRefBase::getName() const - { - return decl->nameAndLoc.name; - } - SourceLoc DeclRefBase::getNameLoc() const - { - return decl->nameAndLoc.loc; - } - SourceLoc DeclRefBase::getLoc() const - { - return decl->loc; - } - - DeclRefBase* DeclRefBase::getParent(ASTBuilder* astBuilder) const - { - // Want access to the free function (the 'as' method by default gets priority) - // Can access as method with this->as because it removes any ambiguity. - using Slang::as; - - auto parentDecl = decl->parentDecl; - if (!parentDecl) - return nullptr; - - // Default is to apply the same set of substitutions/specializations - // to the parent declaration as were applied to the child. - Substitutions* substToApply = substitutions; - - if(auto interfaceDecl = as<InterfaceDecl>(decl)) - { - // The declaration being referenced is an `interface` declaration, - // and there might be a this-type substitution in place. - // A reference to the parent of the interface declaration - // should not include that substitution. - if(auto thisTypeSubst = as<ThisTypeSubstitution>(substToApply)) - { - if(thisTypeSubst->interfaceDecl == interfaceDecl) - { - // Strip away that specializations that apply to the interface. - substToApply = thisTypeSubst->getOuter(); - } - } - } - - if (auto parentGenericDecl = as<GenericDecl>(parentDecl)) - { - // The parent of this declaration is a generic, which means - // that the decl-ref to the current declaration might include - // substitutions that specialize the generic parameters. - // A decl-ref to the parent generic should *not* include - // those substitutions. - // - if(auto genericSubst = as<GenericSubstitution>(substToApply)) - { - if(genericSubst->getGenericDecl() == parentGenericDecl) - { - // Strip away the specializations that were applied to the parent. - substToApply = genericSubst->getOuter(); - } - } - } - - return astBuilder->getSpecializedDeclRef(parentDecl, substToApply); - } - - HashCode DeclRefBase::getHashCode() const - { - return combineHash(PointerHash<1>::getHashCode(decl), SubstitutionSet(substitutions).getHashCode()); - } - // IntVal IntegerLiteralValue getIntVal(IntVal* val) { if (auto constantVal = as<ConstantIntVal>(val)) { - return constantVal->value; + return constantVal->getValue(); } SLANG_UNEXPECTED("needed a known integer value"); //return 0; @@ -1125,14 +534,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // HLSLPatchType + Val* getGenericArg(DeclRef<Decl> declRef, Index index) + { + auto subst = SubstitutionSet(declRef).findGenericAppDeclRef(); + if (index < subst->getArgs().getCount()) + return subst->getArgs()[index]; + return nullptr; + } + Type* HLSLPatchType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(getGenericArg(getDeclRef(), 0)); } IntVal* HLSLPatchType::getElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(getGenericArg(getDeclRef(), 1)); } // MeshOutputType @@ -1143,12 +560,12 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt Type* MeshOutputType::getElementType() { - return as<Type>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[0]); + return as<Type>(getGenericArg(getDeclRef(), 0)); } IntVal* MeshOutputType::getMaxElementCount() { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.getSubst())->getArgs()[1]); + return as<IntVal>(getGenericArg(getDeclRef(), 1)); } // Constructors for types @@ -1174,17 +591,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt { DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, declRef).as<TypeDefDecl>(); - return astBuilder->create<NamedExpressionType>(specializedDeclRef); + return astBuilder->getOrCreate<NamedExpressionType>(specializedDeclRef); } FuncType* getFuncType( ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) { - FuncType* funcType = astBuilder->create<FuncType>(); - - funcType->resultType = getResultType(astBuilder, declRef); - funcType->errorType = getErrorCodeType(astBuilder, declRef); + List<Type*> paramTypes; + auto resultType = getResultType(astBuilder, declRef); + auto errorType = getErrorCodeType(astBuilder, declRef); for (auto paramDeclRef : getParameters(astBuilder, declRef)) { auto paramDecl = paramDeclRef.getDecl(); @@ -1204,9 +620,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt paramType = astBuilder->getOutType(paramType); } } - funcType->paramTypes.add(paramType); + paramTypes.add(paramType); } + FuncType* funcType = astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); return funcType; } @@ -1214,40 +631,34 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt ASTBuilder* astBuilder, DeclRef<GenericDecl> const& declRef) { - return astBuilder->create<GenericDeclRefType>(declRef); + return astBuilder->getOrCreate<GenericDeclRefType>(declRef); } NamespaceType* getNamespaceType( ASTBuilder* astBuilder, DeclRef<NamespaceDeclBase> const& declRef) { - auto type = astBuilder->create<NamespaceType>(); - type->declRef = declRef; + auto type = astBuilder->getOrCreate<NamespaceType>(declRef); return type; } SamplerStateType* getSamplerStateType( ASTBuilder* astBuilder) { - return astBuilder->create<SamplerStateType>(); + return astBuilder->getSamplerStateType(); } - ThisTypeSubstitution* findThisTypeSubstitution( - const Substitutions* substs, + SubtypeWitness* findThisTypeWitness( + SubstitutionSet substs, InterfaceDecl* interfaceDecl) { - for(const Substitutions* s = substs; s; s = s->getOuter()) + auto lookupDeclRef = substs.findLookupDeclRef(); + if (!lookupDeclRef) + return nullptr; + if (lookupDeclRef->getSupDecl() == interfaceDecl) { - auto thisTypeSubst = as<ThisTypeSubstitution>(s); - if(!thisTypeSubst) - continue; - - if(thisTypeSubst->interfaceDecl != interfaceDecl) - continue; - - return const_cast<ThisTypeSubstitution*>(thisTypeSubst); + return lookupDeclRef->getWitness(); } - return nullptr; } @@ -1259,20 +670,16 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto substAssocTypeDecl = substDeclRef.getDecl(); - for (auto s = substDeclRef.getSubst(); s; s = s->getOuter()) + if (auto lookupDeclRef = SubstitutionSet(substDeclRef).findLookupDeclRef()) { - auto thisSubst = as<ThisTypeSubstitution>(s); - if (!thisSubst) - continue; - if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) { - if (thisSubst->interfaceDecl == interfaceDecl) + if (lookupDeclRef->getSupDecl() == interfaceDecl) { // We need to look up the declaration that satisfies // the requirement named by the associated type. Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey); + RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, lookupDeclRef->getWitness(), requirementKey); switch (requirementWitness.getFlavor()) { default: @@ -1296,17 +703,17 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if (builtinReq->kind != BuiltinRequirementKind::DifferentialType) return nullptr; // Is the concrete type a Differential associated type? - auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub); + auto innerDeclRefType = as<DeclRefType>(lookupDeclRef->getWitness()->getSub()); if (!innerDeclRefType) return nullptr; - auto innerBuiltinReq = innerDeclRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>(); + auto innerBuiltinReq = innerDeclRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>(); if (!innerBuiltinReq) return nullptr; if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) return nullptr; - if (!innerDeclRefType->declRef.equals(declRef)) + if (!innerDeclRefType->getDeclRef().equals(declRef)) { - auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->declRef); + auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->getDeclRef()); if (result) return result; } @@ -1320,119 +727,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - String DeclRefBase::toString() const - { - StringBuilder builder; - toText(builder); - return std::move(builder); - } - - // Prints a partially qualified type name with generic substitutions. - void _printNestedDecl(const Substitutions* substitutions, const Decl* decl, StringBuilder& out) - { - // If there is a parent scope for the declaration, print it first. - // Exclude top-level namespaces like `tu0` or `core`. - if (decl->parentDecl && !Slang::as<ModuleDecl>(decl->parentDecl)) - { - auto parentGeneric = Slang::as<GenericDecl>(decl->parentDecl); - - // Exclude function or operator names. - // Avoids excessively verbose messages like `func<T>(func::T)` - if (!(parentGeneric && Slang::as<CallableDecl>(parentGeneric->inner))) - { - _printNestedDecl(substitutions, decl->parentDecl, out); - - // If the parent is a generic for this type, skip *this* type. - // Avoids duplicate types like `MyType<T>::MyType` - if (parentGeneric && parentGeneric->inner == decl) - return; - - out << "."; - } - } - // If we have a ThisTypeSubstitution to an interface decl, print the substituted sub - // type instead. - for (;;) - { - if (auto interfaceDecl = const_cast<InterfaceDecl*>(as<InterfaceDecl>(decl))) - { - if (auto thisSubst = findThisTypeSubstitution(substitutions, interfaceDecl)) - { - if (auto subTypeWitness = as<SubtypeWitness>(thisSubst->witness)) - { - out << subTypeWitness->sub; - break; - } - } - } - // Otherwise, just print this type's name. - auto name = decl->getName(); - if (name) - { - out << name->text; - } - break; - } - - // Look for generic substitutions on this type. - for (const Substitutions* subst = substitutions; subst; subst = subst->getOuter()) - { - auto genericSubstitution = Slang::as<GenericSubstitution>(subst); - if (!genericSubstitution) - continue; - - // If the substitution is for this type, print it. - if (genericSubstitution->getGenericDecl() == decl) - { - out << "<"; - bool isFirst = true; - for (const auto& it : genericSubstitution->getArgs()) - { - // Don't print out witnesses. - if (as<Witness>(it)) - continue; - if (!isFirst) - out << ", "; - isFirst = false; - it->toText(out); - } - out << ">"; - - break; - } - } - } - - void DeclRefBase::toText(StringBuilder& out) const - { - if (decl) - { - _printNestedDecl(substitutions, decl, out); - } - } - - bool SubstitutionSet::equals(const SubstitutionSet& substSet) const - { - if (substitutions == substSet.substitutions) - { - return true; - } - if (substitutions == nullptr || substSet.substitutions == nullptr) - { - return false; - } - return substitutions->equals(substSet.substitutions); - } - - HashCode SubstitutionSet::getHashCode() const - { - HashCode rs = 0; - if (substitutions) - rs = combineHash(rs, substitutions->getHashCode()); - return rs; - } - - ModuleDecl* getModuleDecl(Decl* decl) { for( auto dd = decl; dd; dd = dd->parentDecl ) |
