diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-builder.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 187 |
3 files changed, 134 insertions, 61 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 029c24216..5b4ec5538 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -336,7 +336,7 @@ public: case ASTNodeType::ThisTypeDecl: case ASTNodeType::ExtensionDecl: case ASTNodeType::AssocTypeDecl: - return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl); + return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl).template as<T>(); default: break; } @@ -396,13 +396,13 @@ public: return getOrCreate<GenericAppDeclRef>(innerDecl, genericDeclRef, args); } - LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup) + DeclRef<Decl> getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup) { auto result = getOrCreate<LookupDeclRef>(declToLookup, base, subtypeWitness); return result; } - LookupDeclRef* getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup) + DeclRef<Decl> getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup) { return getLookupDeclRef(subtypeWitness->getSub(), subtypeWitness, declToLookup); } diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 2dc746e09..47cd68b9e 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -559,7 +559,7 @@ DeclRef<ThisTypeDecl> ExtractExistentialType::getThisTypeDeclRef() } SLANG_ASSERT(thisTypeDecl); - DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl); + DeclRef<ThisTypeDecl> specialiedInterfaceDeclRef = getCurrentASTBuilder()->getLookupDeclRef(openedWitness, thisTypeDecl).as<ThisTypeDecl>(); this->cachedThisTypeDeclRef = specialiedInterfaceDeclRef; return specialiedInterfaceDeclRef; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 969c87981..cd25e9d66 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2247,16 +2247,35 @@ namespace Slang // from requirementDeclRef to get the generic arguments for the outer generic, and // apply it to the newly synthesized decl. SubstitutionSet substSet; + Type* thisType = nullptr; if (auto thisWitness = findThisTypeWitness( SubstitutionSet(requirementDeclRef), as<InterfaceDecl>(requirementDeclRef.getParent()).getDecl())) { - if (auto declRefType = as<DeclRefType>(thisWitness->getSub())) + thisType = thisWitness->getSub(); + if (auto declRefType = as<DeclRefType>(thisType)) { substSet = SubstitutionSet(declRefType->getDeclRef()); } } - auto satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl)); + if (!substSet.declRef) + return false; + Type* satisfyingType = nullptr; + if (substSet.declRef->getDecl() == context->parentDecl) + { + // The type we are synthesizing conformance for is direct inside a type itself. + // We need to copy the outer generic arguments to the synthesized type. + satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(substSet.declRef, aggTypeDecl)); + } + else if (auto parentExtDecl = as<ExtensionDecl>(context->parentDecl)) + { + // The type is defined in an extension, we need to form a declref to the parent + // extension from the requirementDeclRef. + auto extDeclRef = applyExtensionToType(parentExtDecl, thisType); + satisfyingType = DeclRefType::create(m_astBuilder, m_astBuilder->getMemberDeclRef(extDeclRef, aggTypeDecl)); + } + if (!satisfyingType) + return false; // Helper function to add a `diffType` field into the synthesized type for the original // `member`. @@ -2683,6 +2702,23 @@ namespace Slang _registerBuiltinDeclsRec(session, decl); } + void discoverExtensionDecls(List<ExtensionDecl*>& decls, Decl* parent) + { + if (auto extDecl = as<ExtensionDecl>(parent)) + decls.add(extDecl); + if (auto containerDecl = as<ContainerDecl>(parent)) + { + for (auto child : containerDecl->members) + { + discoverExtensionDecls(decls, child); + } + } + if (auto genericDecl = as<GenericDecl>(parent)) + { + discoverExtensionDecls(decls, genericDecl->inner); + } + } + void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) { // When we are dealing with code from the standard library, @@ -2824,6 +2860,23 @@ namespace Slang DeclCheckState::DefinitionChecked, DeclCheckState::CapabilityChecked, }; + + // Discover and check all extension decls before anything else. + List<ExtensionDecl*> extensionDecls; + discoverExtensionDecls(extensionDecls, moduleDecl); + for (auto s : states) + { + for (auto extensionDecl : extensionDecls) + { + ensureDecl(extensionDecl, s); + } + // We only need to check extension decls up to ReadyForLookup + // so they are properly registered in type inheritance infos. + if (s == DeclCheckState::ReadyForLookup) + break; + } + + // With extensions taken care of, we can now check the remaining decls. for(auto s : states) { // When advancing to state `s` we will recursively @@ -5183,12 +5236,12 @@ namespace Slang } else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>()) { - synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness( - context, funcDeclRef, synArgs, synThis)); + synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis)); } - + SLANG_ASSERT(synFunc); - + addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>()); if (synGeneric) @@ -5231,49 +5284,49 @@ namespace Slang switch (pattern) { - case SynthesisPattern::AllInductive: + case SynthesisPattern::AllInductive: + { + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + + memberExpr->name = derivMemberName; + + paramFields.add(memberExpr); + inductiveArgMask.add(true); + } + break; + } + case SynthesisPattern::FixedFirstArg: + { + int paramIndex = 0; + for (auto arg : synArgs) { - for (auto arg : synArgs) + if (paramIndex == 0) + { + paramFields.add(arg); + inductiveArgMask.add(false); + + paramIndex++; + } + else { auto memberExpr = m_astBuilder->create<MemberExpr>(); memberExpr->baseExpression = arg; memberExpr->name = derivMemberName; - paramFields.add(memberExpr); inductiveArgMask.add(true); - } - break; - } - case SynthesisPattern::FixedFirstArg: - { - int paramIndex = 0; - for (auto arg : synArgs) - { - if (paramIndex == 0) - { - paramFields.add(arg); - inductiveArgMask.add(false); - - paramIndex++; - } - else - { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - - memberExpr->name = derivMemberName; - paramFields.add(memberExpr); - inductiveArgMask.add(true); - paramIndex++; - } + paramIndex++; } - break; } - default: - SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); - break; + break; + } + default: + SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); + break; } // Invoke the method for the field and assign the value to resultVar. @@ -5294,8 +5347,9 @@ namespace Slang auto synReturn = m_astBuilder->create<ReturnStmt>(); synReturn->expression = resultVarExpr; seqStmt->stmts.add(synReturn); - - context->parentDecl->members.add(synFunc); + + Decl* witnessDecl = synGeneric ? (Decl*)synGeneric : synFunc; + context->parentDecl->members.add(witnessDecl); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); @@ -5313,21 +5367,29 @@ namespace Slang substSet = SubstitutionSet(declRefType->getDeclRef()); } } - if (auto outerGeneric = GetOuterGeneric(context->parentDecl)) + if (!substSet.declRef) + return false; + DeclRef<Decl> synthesizedWitnessDeclRef; + if (auto parentExtDecl = as<ExtensionDecl>(context->parentDecl)) { - // If the context->parentDecl is not the same as ThisType represented by genApp, then it must be an extension - // to ThisType. In this case, we need to form a new GenericAppDeclRef to specailizethe outer parent extension - // decl. Note that the extension might be a partial extension with some generic arguments missing, and - // we can't support that case right now. For now we can just assume the extension will have the same set - // of generic parameters as the target type. - auto defaultArgs = getDefaultSubstitutionArgs(m_astBuilder, this, outerGeneric); - auto specializedParent = m_astBuilder->getGenericAppDeclRef(makeDeclRef(outerGeneric), defaultArgs.getArrayView()); - auto specializedFunc = m_astBuilder->getMemberDeclRef(specializedParent, synFunc); - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(specializedFunc)); - return true; - } + // If the conformance is declared on an extension to ThisType, + // we need to form a new proper decl ref to the parent extension decl + // with the correct specialization arguments. + // + if (GetOuterGeneric(context->parentDecl)) + { - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc))); + auto extDeclRef = applyExtensionToType(parentExtDecl, context->conformingType); + synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(extDeclRef, witnessDecl); + } + } + else + { + synthesizedWitnessDeclRef = m_astBuilder->getMemberDeclRef(substSet.declRef, witnessDecl); + } + if (!synthesizedWitnessDeclRef) + synthesizedWitnessDeclRef = m_astBuilder->getDirectDeclRef(witnessDecl); + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(synthesizedWitnessDeclRef)); return true; } @@ -5351,6 +5413,15 @@ namespace Slang // with the same name in the type declaration and // its (known) extensions. + // The exception to that is when the requiredMemberDeclRef is already + // resolved to the actual satisfying decl, in which case we simply return + // true without any further lookup. + if (!as<InterfaceDecl>(requiredMemberDeclRef.getParent().getDecl())) + return true; + + // If `requiredMemberDeclRef` is a lookup decl ref for an interface requirement + // we attempt to do the loopkup through witness tables. + // // As a first pass, lets check if we already have a // witness in the table for the requirement, so // that we can bail out early. @@ -5655,7 +5726,7 @@ namespace Slang subType, superInterfaceType, inheritanceDecl, - thisTypeDeclRef, + superInterfaceDeclRef, requiredMemberDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); @@ -5674,7 +5745,7 @@ namespace Slang subType, superInterfaceType, inheritanceDecl, - thisTypeDeclRef, + superInterfaceDeclRef, requiredMemberDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); @@ -5726,7 +5797,7 @@ namespace Slang subType, superInterfaceType, inheritanceDecl, - thisTypeDeclRef, + superInterfaceDeclRef, requiredInheritanceDeclRef, witnessTable, subTypeConformsToSuperInterfaceWitness); @@ -8551,7 +8622,9 @@ namespace Slang // Looks like we have a match in the types, // now let's see if `type`'s declref starts with a Lookup. targetType = type; - extDeclRef = m_astBuilder->getLookupDeclRef(thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl()); + extDeclRef = m_astBuilder->getLookupDeclRef( + thisTypeLookupDeclRef->getWitness(), extDeclRef.getDecl()) + .as<ExtensionDecl>(); } } } |
