diff options
| author | Yong He <yonghe@outlook.com> | 2018-01-09 10:50:44 -0800 |
|---|---|---|
| committer | Tim Foley <tfoleyNV@users.noreply.github.com> | 2018-01-09 10:50:44 -0800 |
| commit | 8daafcc2e4bf7b2dfb66d7a3b7ac60c86b2d926c (patch) | |
| tree | b7fac301e3c4d1b006af70584feeb45af191aab6 | |
| parent | 3d435f7321c3f9241d33a0f7521573f21b548186 (diff) | |
bruteforce implementation of witness table resolution for associated (#358)
| -rw-r--r-- | source/slang/check.cpp | 221 | ||||
| -rw-r--r-- | source/slang/decl-defs.h | 29 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 24 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 78 | ||||
| -rw-r--r-- | source/slang/lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 64 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 2 | ||||
| -rw-r--r-- | source/slang/syntax-base-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 80 | ||||
| -rw-r--r-- | source/slang/syntax.h | 7 | ||||
| -rw-r--r-- | source/slang/val-defs.h | 11 | ||||
| -rw-r--r-- | tests/compute/assoctype-generic-arg.slang | 38 | ||||
| -rw-r--r-- | tests/compute/assoctype-generic-arg.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/assoctype-simple.slang | 6 | ||||
| -rw-r--r-- | tools/render-test/test.txt | 1 |
16 files changed, 421 insertions, 149 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index ca9e0e7e5..77f9c8f44 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -153,6 +153,8 @@ namespace Slang RefPtr<Expr> baseExpr, SourceLoc loc) { + if (declRef.As<AssocTypeDecl>()) + getNewThisTypeSubst(declRef); if (baseExpr) { RefPtr<Expr> expr; @@ -182,26 +184,18 @@ namespace Slang if (auto baseDeclRefExpr = baseExpr->As<DeclRefExpr>()) { baseThisTypeSubst = getThisTypeSubst(baseDeclRefExpr->declRef, false); - if (auto baseAssocType = baseDeclRefExpr->declRef.As<AssocTypeDecl>()) - { - baseThisTypeSubst = new ThisTypeSubstitution(); - baseThisTypeSubst->sourceType = baseDeclRefExpr->type.type; - if (auto typetype = baseThisTypeSubst->sourceType.As<TypeType>()) - baseThisTypeSubst->sourceType = typetype->type; - } } - if (auto assocTypeDecl = declRef.As<AssocTypeDecl>()) + if (declRef.As<TypeConstraintDecl>()) { - auto newThisTypeSubst = new ThisTypeSubstitution(); - if (baseThisTypeSubst) - newThisTypeSubst->sourceType = baseThisTypeSubst->sourceType; - expr->type = GetTypeForDeclRef(DeclRef<AssocTypeDecl>(assocTypeDecl.getDecl(), newThisTypeSubst)); - auto declOutThisTypeSubst = getNewThisTypeSubst(*declRefOut); - if (baseThisTypeSubst) - declOutThisTypeSubst->sourceType = baseThisTypeSubst->sourceType; - return expr; + // if this is a reference to type constraint, insert a this-type substitution + RefPtr<Type> expType; + expType = baseExpr->type; + if (auto baseExprTT = baseExpr->type->As<TypeType>()) + expType = baseExprTT->type; + auto thisTypeSubst = getNewThisTypeSubst(*declRefOut); + thisTypeSubst->sourceType = expType; + baseThisTypeSubst = nullptr; } - // propagate "this-type" substitutions if (baseThisTypeSubst) { @@ -210,7 +204,7 @@ namespace Slang getNewThisTypeSubst(declRefExpr->declRef)->sourceType = baseThisTypeSubst->sourceType; } } - expr->type = GetTypeForDeclRef(declRef); + expr->type = GetTypeForDeclRef(*declRefOut); return expr; } else @@ -219,21 +213,6 @@ namespace Slang expr->loc = loc; expr->name = declRef.GetName(); expr->type = GetTypeForDeclRef(declRef); - if (auto exprDeclRefType = getExprDeclRefType(expr)) - { - if (auto genParmDecl = exprDeclRefType->declRef.As<GenericTypeParamDecl>()) - { - // if this is a reference to generic type param, insert a this-type substitution - auto exprType = GetTypeForDeclRef(declRef); - auto thisSubst = new ThisTypeSubstitution(); - if (auto typetype = exprType.type.As<TypeType>()) - thisSubst->sourceType = typetype->type; - else - thisSubst->sourceType = exprType.type; - thisSubst->outer = declRef.substitutions; - declRef.substitutions = thisSubst; - } - } expr->declRef = declRef; return expr; } @@ -1614,27 +1593,27 @@ namespace Slang } bool doesSignatureMatchRequirement( - CallableDecl* /*memberDecl*/, - DeclRef<CallableDecl> requiredMemberDeclRef) + DeclRef<CallableDecl> memberDecl, + DeclRef<CallableDecl> requiredMemberDeclRef, + Dictionary<DeclRef<Decl>, DeclRef<Decl>> & requirementDict) { // TODO: actually implement matching here. For now we'll - // just pretend that things are satisfied in order to make progress. + // just pretend that things are satisfied in order to make progress.. + requirementDict.Add(requiredMemberDeclRef, DeclRef<Decl>(memberDecl, nullptr)); return true; } bool doesGenericSignatureMatchRequirement( - GenericDecl * genDecl, - DeclRef<GenericDecl> requirementGenDecl) + DeclRef<GenericDecl> genDecl, + DeclRef<GenericDecl> requirementGenDecl, + Dictionary<DeclRef<Decl>, DeclRef<Decl>> & requirementDict) { - // TODO: genDecl should be a DeclRef to capture the environment generic variables needed to get - // a concrete type for a generic constraint super type (e.g. when this member belongs to a generic type) - - if (genDecl->Members.Count() != requirementGenDecl.getDecl()->Members.Count()) + if (genDecl.getDecl()->Members.Count() != requirementGenDecl.getDecl()->Members.Count()) return false; - for (UInt i = 0; i < genDecl->Members.Count(); i++) + for (UInt i = 0; i < genDecl.getDecl()->Members.Count(); i++) { - auto genMbr = genDecl->Members[i]; - auto requiredGenMbr = genDecl->Members[i]; + auto genMbr = genDecl.getDecl()->Members[i]; + auto requiredGenMbr = genDecl.getDecl()->Members[i]; if (auto genTypeMbr = genMbr.As<GenericTypeParamDecl>()) { if (auto requiredGenTypeMbr = requiredGenMbr.As<GenericTypeParamDecl>()) @@ -1647,7 +1626,8 @@ namespace Slang { if (auto requiredGenValMbr = requiredGenMbr.As<GenericValueParamDecl>()) { - return genValMbr->type->Equals(requiredGenValMbr->type); + if (!genValMbr->type->Equals(requiredGenValMbr->type)) + return false; } else return false; @@ -1665,16 +1645,18 @@ namespace Slang return false; } } - return doesMemberSatisfyRequirement(genDecl->inner.Ptr(), - DeclRef<Decl>(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions)); + return doesMemberSatisfyRequirement(DeclRef<Decl>(genDecl.getDecl()->inner.Ptr(), genDecl.substitutions), + DeclRef<Decl>(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions), + requirementDict); } // Does the given `memberDecl` work as an implementation // to satisfy the requirement `requiredMemberDeclRef` // from an interface? bool doesMemberSatisfyRequirement( - Decl* memberDecl, - DeclRef<Decl> requiredMemberDeclRef) + DeclRef<Decl> memberDeclRef, + DeclRef<Decl> requiredMemberDeclRef, + Dictionary<DeclRef<Decl>, DeclRef<Decl>> & requirementDictionary) { // At a high level, we want to chack that the // `memberDecl` and the `requiredMemberDeclRef` @@ -1694,74 +1676,85 @@ namespace Slang // An associated type requirement should be allowed // to be satisfied by any type declaration: // a typedef, a `struct`, etc. - - if (auto memberFuncDecl = dynamic_cast<FuncDecl*>(memberDecl)) + auto checkSubTypeMember = [&](DeclRef<AggTypeDecl> subStructTypeDeclRef) -> bool + { + EnsureDecl(subStructTypeDeclRef.getDecl()); + // this is a sub type (e.g. nested struct declaration) in an aggregate type + // check if this sub type declaration satisfies the constraints defined by the associated type + if (auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>()) + { + bool conformance = true; + auto inheritanceReqDeclRefs = getMembersOfType<InheritanceDecl>(requiredTypeDeclRef); + for (auto inheritanceReqDeclRef : inheritanceReqDeclRefs) + { + auto interfaceDeclRefType = inheritanceReqDeclRef.getDecl()->base.type.As<DeclRefType>(); + SLANG_ASSERT(interfaceDeclRefType); + auto interfaceDeclRef = interfaceDeclRefType->declRef.As<InterfaceDecl>(); + SLANG_ASSERT(interfaceDeclRef); + RefPtr<DeclRefType> declRefType = new DeclRefType(); + declRefType->declRef = subStructTypeDeclRef; + auto witness = tryGetInterfaceConformanceWitness(declRefType, + interfaceDeclRef).As<SubtypeWitness>(); + if (witness) + requirementDictionary.Add(inheritanceReqDeclRef, witness->getLastStepDeclRef()); + else + conformance = false; + } + return conformance; + } + return false; + }; + if (auto memberFuncDecl = memberDeclRef.As<FuncDecl>()) { if (auto requiredFuncDeclRef = requiredMemberDeclRef.As<FuncDecl>()) { // Check signature match. return doesSignatureMatchRequirement( memberFuncDecl, - requiredFuncDeclRef); + requiredFuncDeclRef, + requirementDictionary); } } - else if (auto memberInitDecl = dynamic_cast<ConstructorDecl*>(memberDecl)) + else if (auto memberInitDecl = memberDeclRef.As<ConstructorDecl>()) { if (auto requiredInitDecl = requiredMemberDeclRef.As<ConstructorDecl>()) { // Check signature match. return doesSignatureMatchRequirement( memberInitDecl, - requiredInitDecl); + requiredInitDecl, + requirementDictionary); } } - else if (auto genDecl = dynamic_cast<GenericDecl*>(memberDecl)) + else if (auto genDecl = memberDeclRef.As<GenericDecl>()) { if (auto requiredGenDeclRef = requiredMemberDeclRef.As<GenericDecl>()) { - return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef); + return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, requirementDictionary); } } - else if (auto subStructTypeDecl = dynamic_cast<AggTypeDecl*>(memberDecl)) + else if (auto subStructTypeDeclRef = memberDeclRef.As<AggTypeDecl>()) { - // this is a sub type (e.g. nested struct declaration) in an aggregate type - // check if this sub type declaration satisfies the constraints defined by the associated type - if (auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>()) - { - bool conformance = true; - for (auto & inheritanceDecl : requiredTypeDeclRef.getDecl()->getMembersOfType<InheritanceDecl>()) - { - conformance = conformance && checkConformance(subStructTypeDecl, inheritanceDecl.Ptr()); - } - return conformance; - } + return checkSubTypeMember(subStructTypeDeclRef); } - else if (auto typedefDecl = dynamic_cast<TypeDefDecl*>(memberDecl)) + else if (auto typedefDeclRef = memberDeclRef.As<TypeDefDecl>()) { // this is a type-def decl in an aggregate type // check if the specified type satisfies the constraints defined by the associated type - if (auto requiredTypeDecl = requiredMemberDeclRef.As<AssocTypeDecl>()) + if (auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>()) { - auto constraintList = requiredTypeDecl.getDecl()->getMembersOfType<InheritanceDecl>(); + auto constraintList = getMembersOfType<InheritanceDecl>(requiredTypeDeclRef); if (constraintList.Count()) { - auto declRefType = typedefDecl->type->AsDeclRefType(); + auto declRefType = GetType(typedefDeclRef)->GetCanonicalType()->As<DeclRefType>(); if (!declRefType) return false; - auto structTypeDecl = declRefType->declRef.getDecl()->As<AggTypeDecl>(); - if (!structTypeDecl) + auto structTypeDeclRef = declRefType->declRef.As<AggTypeDecl>(); + if (!structTypeDeclRef) return false; - //TODO: What do we do if type is a generic specialization? - // i.e. if the struct defines typedef Generic<float> T; - // how to check if T satisfies the associatedtype constraints? - // the code below will only work when T is defined to be a simple aggregated type (no generics). - bool conformance = true; - for (auto & inheritanceDecl : constraintList) - { - conformance = conformance && checkConformance(structTypeDecl, inheritanceDecl.Ptr()); - } - return conformance; + + return checkSubTypeMember(structTypeDeclRef); } return true; } @@ -1779,10 +1772,11 @@ namespace Slang // `requiredMemberDeclRef` is a required member of // the interface. RefPtr<Decl> findWitnessForInterfaceRequirement( - AggTypeDecl* typeDecl, + DeclRef<AggTypeDecl> typeDeclRef, InheritanceDecl* inheritanceDecl, DeclRef<InterfaceDecl> interfaceDeclRef, - DeclRef<Decl> requiredMemberDeclRef) + DeclRef<Decl> requiredMemberDeclRef, + Dictionary<DeclRef<Decl>, DeclRef<Decl>> & requirementWitness) { // We will look up members with the same name, // since only same-name members will be able to @@ -1816,14 +1810,14 @@ namespace Slang // now, so we won't worry about this. // Make sure that by-name lookup is possible. - buildMemberDictionary(typeDecl); + buildMemberDictionary(typeDeclRef.getDecl()); Decl* firstMemberOfName = nullptr; - typeDecl->memberDictionary.TryGetValue(name, firstMemberOfName); + typeDeclRef.getDecl()->memberDictionary.TryGetValue(name, firstMemberOfName); if (!firstMemberOfName) { - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDecl, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); return nullptr; } @@ -1831,7 +1825,7 @@ namespace Slang // the expected signature for the requirement. for (auto memberDecl = firstMemberOfName; memberDecl; memberDecl = memberDecl->nextInContainerWithSameName) { - if (doesMemberSatisfyRequirement(memberDecl, requiredMemberDeclRef)) + if (doesMemberSatisfyRequirement(DeclRef<Decl>(memberDecl, typeDeclRef.substitutions), requiredMemberDeclRef, requirementWitness)) return memberDecl; } @@ -1842,7 +1836,7 @@ namespace Slang // of "candidates" for satisfaction of the requirement, // and if nothing is found we print the candidates - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDecl, requiredMemberDeclRef); + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); return nullptr; } @@ -1851,7 +1845,7 @@ namespace Slang // (via the given `inheritanceDecl`) actually provides // members to satisfy all the requirements in the interface. bool checkInterfaceConformance( - AggTypeDecl* typeDecl, + DeclRef<AggTypeDecl> typeDeclRef, InheritanceDecl* inheritanceDecl, DeclRef<InterfaceDecl> interfaceDeclRef) { @@ -1884,7 +1878,7 @@ namespace Slang // // TODO: we *really* need a linearization step here!!!! result = result && checkConformanceToType( - typeDecl, + typeDeclRef, inheritanceDecl, getBaseType(requiredInheritanceDeclRef)); continue; @@ -1892,30 +1886,26 @@ namespace Slang // Look for a member in the type that can satisfy the // interface requirement. - auto conformanceWitness = findWitnessForInterfaceRequirement( - typeDecl, + auto isConformanceSatisfied = findWitnessForInterfaceRequirement( + typeDeclRef, inheritanceDecl, interfaceDeclRef, - requiredMemberDeclRef); + requiredMemberDeclRef, + inheritanceDecl->requirementWitnesses); - if (!conformanceWitness) + if (!isConformanceSatisfied) { result = false; continue; } - - // Store that witness into a table stored on the `inheritnaceDecl` - // so that it can be used for downstream code generation. - - inheritanceDecl->requirementWitnesses.Add(requiredMemberDeclRef, conformanceWitness); } return result; } bool checkConformanceToType( - AggTypeDecl* typeDecl, - InheritanceDecl* inheritanceDecl, - Type* baseType) + DeclRef<AggTypeDecl> typeDeclRef, + InheritanceDecl* inheritanceDecl, + Type* baseType) { if (auto baseDeclRefType = baseType->As<DeclRefType>()) { @@ -1926,7 +1916,7 @@ namespace Slang // We need to check that it provides all of the members // required by that interface. return checkInterfaceConformance( - typeDecl, + typeDeclRef, inheritanceDecl, baseInterfaceDeclRef); } @@ -1941,13 +1931,20 @@ namespace Slang // `inheritanceDecl` actually does what it needs to // for that inheritance to be valid. bool checkConformance( - AggTypeDecl* typeDecl, - InheritanceDecl* inheritanceDecl) + DeclRef<AggTypeDecl> typeDecl, + InheritanceDecl* inheritanceDecl) { // Look at the type being inherited from, and validate // appropriately. auto baseType = inheritanceDecl->base.type; - return checkConformanceToType(typeDecl, inheritanceDecl, baseType); + return checkConformanceToType(typeDecl, inheritanceDecl, baseType.As<Type>()); + } + + bool checkConformance( + AggTypeDecl* typeDecl, + InheritanceDecl* inheritanceDecl) + { + return checkConformance(DeclRef<AggTypeDecl>(typeDecl, nullptr), inheritanceDecl); } void visitAggTypeDecl(AggTypeDecl* decl) @@ -5095,7 +5092,7 @@ namespace Slang RefPtr<Substitutions> snd) { // They must both be NULL or non-NULL - if (!fst || !snd) + if (!hasGenericSubstitutions(fst) || !hasGenericSubstitutions(snd)) return fst == snd; auto fstGen = fst.As<GenericSubstitution>(); auto sndGen = snd.As<GenericSubstitution>(); @@ -6927,7 +6924,7 @@ namespace Slang auto type = getFuncType(session, funcDeclRef); return QualType(type); } - else if (auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>()) + else if (auto constraintDeclRef = declRef.As<TypeConstraintDecl>()) { // When we access a constraint or an inheritance decl (as a member), // we are conceptually performing a "cast" to the given super-type, diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index 7cb9ffc0f..fb35e327a 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -90,21 +90,30 @@ SIMPLE_SYNTAX_CLASS(ClassDecl, AggTypeDecl) // An interface which other types can conform to SIMPLE_SYNTAX_CLASS(InterfaceDecl, AggTypeDecl) +ABSTRACT_SYNTAX_CLASS(TypeConstraintDecl, Decl) + RAW( + virtual TypeExp& getSup() = 0; + ) +END_SYNTAX_CLASS() + // A kind of pseudo-member that represents an explicit // or implicit inheritance relationship. // -SYNTAX_CLASS(InheritanceDecl, Decl) - // The type expression as written +SYNTAX_CLASS(InheritanceDecl, TypeConstraintDecl) +// The type expression as written SYNTAX_FIELD(TypeExp, base) -RAW( + RAW( // After checking, this dictionary will map members // required by the base type to their concrete // implementations in the type that contains // this inheritance declaration. - Dictionary<DeclRef<Decl>, Decl*> requirementWitnesses; -) - + Dictionary<DeclRef<Decl>, DeclRef<Decl>> requirementWitnesses; + virtual TypeExp& getSup() override + { + return base; + } + ) END_SYNTAX_CLASS() // TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance @@ -216,13 +225,19 @@ SYNTAX_CLASS(GenericTypeParamDecl, SimpleTypeDecl) END_SYNTAX_CLASS() // A constraint placed as part of a generic declaration -SYNTAX_CLASS(GenericTypeConstraintDecl, Decl) +SYNTAX_CLASS(GenericTypeConstraintDecl, TypeConstraintDecl) // A type constraint like `T : U` is constraining `T` to be "below" `U` // on a lattice of types. This may not be a subtyping relationship // per se, but it makes sense to use that terminology here, so we // think of these fields as the sub-type and sup-ertype, respectively. SYNTAX_FIELD(TypeExp, sub) SYNTAX_FIELD(TypeExp, sup) + RAW( + virtual TypeExp& getSup() override + { + return sup; + } + ) END_SYNTAX_CLASS() SIMPLE_SYNTAX_CLASS(GenericValueParamDecl, VarDeclBase) diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 7eafe89f7..cc5474842 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -98,6 +98,7 @@ INST(decl_ref, decl_ref, 0, 0) INST(specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) +INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(Construct, construct, 0, 0) INST(Call, call, 1, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index a8c8383e2..55ff26185 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -84,6 +84,12 @@ struct IRLookupWitnessMethod : IRInst IRUse requirementDeclRef; }; +struct IRLookupWitnessTable : IRInst +{ + IRUse sourceType; + IRUse interfaceType; +}; + // struct IRCall : IRInst @@ -309,6 +315,14 @@ struct IRWitnessTable : IRGlobalValue IRValueList<IRWitnessTableEntry> entries; }; +// An abstract witness table is a global value that +// represents an inheritance relationship that can't +// be resolved to a witness table at IR-generation time. +struct IRAbstractWitness : IRGlobalValue +{ + RefPtr<SubtypeWitness> witness; + DeclRef<Decl> subTypeDeclRef, supTypeDeclRef; +}; // Description of an instruction to be used for global value numbering @@ -402,6 +416,15 @@ struct IRBuilder DeclRef<Decl> witnessTableDeclRef, DeclRef<Decl> interfaceMethodDeclRef); + IRValue* emitLookupInterfaceMethodInst( + IRType* type, + IRValue* witnessTableVal, + DeclRef<Decl> interfaceMethodDeclRef); + + IRValue* emitFindWitnessTable( + DeclRef<Decl> baseTypeDeclRef, + IRType* interfaceType); + IRInst* emitCallInst( IRType* type, IRValue* func, @@ -424,7 +447,6 @@ struct IRBuilder IRFunc* createFunc(); IRGlobalVar* createGlobalVar( IRType* valueType); - IRWitnessTable* createWitnessTable(Dictionary<DeclRef<Decl>, Decl*> & witnesses); IRWitnessTable* createWitnessTable(); IRWitnessTableEntry* createWitnessTableEntry( IRWitnessTable* witnessTable, diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 088139953..1c06f5be9 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -678,11 +678,38 @@ namespace Slang DeclRef<Decl> interfaceMethodDeclRef) { auto witnessTableVal = getDeclRefVal(witnessTableDeclRef); - auto interfaceMethodVal = getDeclRefVal(interfaceMethodDeclRef); + DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; + removeSubstDeclRef.substitutions = nullptr; + auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); } + IRValue* IRBuilder::emitLookupInterfaceMethodInst( + IRType* type, + IRValue* witnessTableVal, + DeclRef<Decl> interfaceMethodDeclRef) + { + DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef; + removeSubstDeclRef.substitutions = nullptr; + auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef); + return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal); + } + IRValue* IRBuilder::emitFindWitnessTable( + DeclRef<Decl> baseTypeDeclRef, + IRType* interfaceType) + { + auto interfaceTypeDeclRef = interfaceType->AsDeclRefType(); + SLANG_ASSERT(interfaceTypeDeclRef); + auto inst = createInst<IRLookupWitnessTable>( + this, + kIROp_lookup_witness_table, + interfaceType, + getDeclRefVal(baseTypeDeclRef), + getDeclRefVal(interfaceTypeDeclRef->declRef)); + addInst(inst); + return inst; + } IRInst* IRBuilder::emitCallInst( IRType* type, @@ -3200,7 +3227,6 @@ namespace Slang Dictionary<String, VarLayout*> globalVarLayouts; RefPtr<GlobalGenericParamSubstitution> subst; - // Override the "maybe clone" logic so that we always clone virtual IRValue* maybeCloneValue(IRValue* originalVal) override; @@ -3228,6 +3254,7 @@ namespace Slang return val->Substitute(subst); } + IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue) { switch (originalValue->op) @@ -3261,14 +3288,17 @@ namespace Slang case kIROp_decl_ref: { IRDeclRef* od = (IRDeclRef*)originalValue; + auto newDeclRef = od->declRef; // if the declRef is one of the __generic_param decl being substituted by subst // return the substituted decl if (subst) { - if (od->declRef.getDecl() == subst->paramDecl) + int diff = 0; + newDeclRef = od->declRef.SubstituteImpl(subst, &diff); + if (newDeclRef.getDecl() == subst->paramDecl) return builder->getTypeVal(subst->actualType.As<Type>()); - else if (auto genConstraint = od->declRef.As<GenericTypeConstraintDecl>()) + else if (auto genConstraint = newDeclRef.As<GenericTypeConstraintDecl>()) { // a decl-ref to GenericTypeConstraintDecl as a result of // referencing a generic parameter type should be replaced with @@ -3288,7 +3318,7 @@ namespace Slang } } } - auto declRef = maybeCloneDeclRef(od->declRef); + auto declRef = maybeCloneDeclRef(newDeclRef); return builder->getDeclRefVal(declRef); } break; @@ -3641,6 +3671,14 @@ namespace Slang // and their instructions. cloneFunctionCommon(context, clonedFunc, originalFunc); + //// for now, clone all unreferenced witness tables + //for (auto gv = context->getOriginalModule()->getFirstGlobalValue(); + // gv; gv = gv->getNextValue()) + //{ + // if (gv->op == kIROp_witness_table) + // cloneGlobalValue(context, (IRWitnessTable*)gv); + //} + // We need to attach the layout information for // the entry point to this declaration, so that // we can use it to inform downstream code emit. @@ -4048,7 +4086,7 @@ namespace Slang globalVar = globalVar->getNextValue(); } SLANG_ASSERT(table); - table = cloneWitnessTableWithoutRegistering(context, (IRWitnessTable*)(table)); + table = cloneGlobalValue(context, (IRWitnessTable*)(table)); IRProxyVal * tableVal = new IRProxyVal(); tableVal->inst.init(nullptr, table); paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); @@ -4661,6 +4699,16 @@ namespace Slang sharedContext->workList.Add(func); } + // Build dictionary for witness tables + Dictionary<String, IRWitnessTable*> witnessTables; + for (auto gv = module->getFirstGlobalValue(); + gv; + gv = gv->getNextValue()) + { + if (gv->op == kIROp_witness_table) + witnessTables.AddIfNotExists(gv->mangledName, (IRWitnessTable*)gv); + } + // Now that we have our work list, we are going to // process it until it goes empty. Along the way // we may specialize a function and thus create @@ -4738,12 +4786,28 @@ namespace Slang // specialize a witness table auto originalTable = (IRWitnessTable*)genericVal; auto specWitnessTable = specializeWitnessTable(sharedContext, originalTable, specDeclRef); + witnessTables.AddIfNotExists(specWitnessTable->mangledName, specWitnessTable); specInst->replaceUsesWith(specWitnessTable); specInst->removeAndDeallocate(); } } break; - + case kIROp_lookup_witness_table: + { + // try find concrete witness table from global scope + IRLookupWitnessTable* lookupInst = (IRLookupWitnessTable*)ii; + IRWitnessTable* witnessTable = nullptr; + auto srcDeclRef = ((IRDeclRef*)lookupInst->sourceType.usedValue)->declRef; + auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.usedValue)->declRef; + auto mangledName = getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef); + witnessTables.TryGetValue(mangledName, witnessTable); + if (witnessTable) + { + lookupInst->replaceUsesWith(witnessTable); + lookupInst->removeAndDeallocate(); + } + } + break; case kIROp_lookup_interface_method: { // We have a `lookup_interface_method` instruction, diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index 86bef3f4d..901661667 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -412,7 +412,7 @@ void lookUpMemberImpl( auto declRef = declRefType->declRef; if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>()) { - for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(declRef.As<ContainerDecl>())) + for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(declRef.As<ContainerDecl>())) { // The super-type in the constraint (e.g., `Foo` in `T : Foo`) // will tell us a type we should use for lookup. diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 1d6da3a3b..e08498fc0 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -470,6 +470,52 @@ LoweredValInfo emitPostOp( return LoweredValInfo::ptr(argPtr); } +IRValue* findWitnessTable( + IRGenContext* context, + DeclRef<Decl> declRef); + +LoweredValInfo emitWitnessTableRef( + IRGenContext* context, + Expr* expr) +{ + if (auto mbrExpr = dynamic_cast<MemberExpr*>(expr)) + { + if (auto inheritanceDeclRef = mbrExpr->declRef.As<InheritanceDecl>()) + { + if (inheritanceDeclRef.getDecl()->ParentDecl->As<InterfaceDecl>() || inheritanceDeclRef.getDecl()->ParentDecl->As<AssocTypeDecl>()) + { + RefPtr<Type> exprType = nullptr; + if (auto tt = mbrExpr->BaseExpression->type->As<TypeType>()) + exprType = tt->type; + else + exprType = mbrExpr->BaseExpression->type; + auto declRefType = exprType->GetCanonicalType()->AsDeclRefType(); + SLANG_ASSERT(declRefType); + IRValue* witnessTableVal = nullptr; + DeclRef<Decl> srcDeclRef = declRefType->declRef; + if (!declRefType->declRef.As<AssocTypeDecl>()) + { + // if we are referring to an actual type, don't include substitution + // and generate specialize instruction + srcDeclRef.substitutions = nullptr; + } + witnessTableVal = context->irBuilder->emitFindWitnessTable(srcDeclRef, inheritanceDeclRef.getDecl()->base.type); + return maybeEmitSpecializeInst(context, LoweredValInfo::simple(witnessTableVal), declRefType->declRef); + } + else if (inheritanceDeclRef.getDecl()->ParentDecl->As<AggTypeDeclBase>()) + { + return LoweredValInfo::simple(findWitnessTable(context, inheritanceDeclRef)); + } + + } + else if (auto genConstraintDeclRef = mbrExpr->declRef.As<GenericTypeConstraintDecl>()) + { + return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(genConstraintDeclRef)); + } + } + SLANG_UNEXPECTED("unknown witness table expression"); +} + // Emit a reference to a function, where we have concluded // that the original AST referenced `funcDeclRef`. The // optional expression `funcExpr` can provide additional @@ -494,7 +540,7 @@ LoweredValInfo emitFuncRef( if(auto baseMemberExpr = baseExpr.As<MemberExpr>()) { auto baseMemberDeclRef = baseMemberExpr->declRef; - if(auto baseConstraintDeclRef = baseMemberDeclRef.As<GenericTypeConstraintDecl>()) + if(auto baseConstraintDeclRef = baseMemberDeclRef.As<TypeConstraintDecl>()) { // We are calling a method "through" a generic type // parameter that was constrained to some type. @@ -505,10 +551,10 @@ LoweredValInfo emitFuncRef( // find the corresponding member on our chosen type. RefPtr<Type> type = funcExpr->type; - + auto loweredBaseWitnessTable = emitWitnessTableRef(context, baseMemberExpr); auto loweredVal = LoweredValInfo::simple(context->irBuilder->emitLookupInterfaceMethodInst( type, - baseMemberDeclRef, + loweredBaseWitnessTable.val, funcDeclRef)); return maybeEmitSpecializeInst(context, loweredVal, funcDeclRef); } @@ -1184,7 +1230,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> boundMemberInfo->declRef = callableDeclRef; return LoweredValInfo::boundMember(boundMemberInfo); } - else if(auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>()) + else if(auto constraintDeclRef = declRef.As<TypeConstraintDecl>()) { // The code is making use of a "witness" that a value of // some generic type conforms to an interface. @@ -2770,10 +2816,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> for (auto entry : inheritanceDecl->requirementWitnesses) { auto requiredMemberDeclRef = entry.Key; - auto satisfyingMemberDecl = entry.Value; - + auto satisfyingMemberDeclRef = entry.Value; + auto irRequirement = context->irBuilder->getDeclRefVal(requiredMemberDeclRef); - auto irSatisfyingVal = getSimpleVal(context, ensureDecl(context, satisfyingMemberDecl)); + IRValue* irSatisfyingVal = nullptr; + if (satisfyingMemberDeclRef.As<GenericTypeConstraintDecl>()) + irSatisfyingVal = context->irBuilder->getDeclRefVal(satisfyingMemberDeclRef); + else + irSatisfyingVal = getSimpleVal(context, ensureDecl(context, satisfyingMemberDeclRef)); context->irBuilder->createWitnessTableEntry( witnessTable, diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 662b75a2c..11e3eb159 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2279,7 +2279,7 @@ namespace Slang auto nameToken = parser->ReadToken(TokenType::Identifier); assocTypeDecl->nameAndLoc = NameLoc(nameToken); assocTypeDecl->loc = nameToken.loc; - parseOptionalGenericConstraints(parser, assocTypeDecl); + parseOptionalInheritanceClause(parser, assocTypeDecl); parser->ReadToken(TokenType::Semicolon); return assocTypeDecl; } diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 1853125a8..de8b294d8 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -180,7 +180,6 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) // The actual type that provides the lookup scope for an associated type SYNTAX_FIELD(RefPtr<Val>, sourceType) - RAW( // Apply a set of substitutions to the bindings in this substitution virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override; @@ -193,6 +192,7 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) } virtual int GetHashCode() const override { + SLANG_ASSERT(sourceType); return sourceType->GetHashCode(); } ) diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index cdc112846..4f043e0a1 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1275,8 +1275,12 @@ void Type::accept(IValVisitor* visitor, void* extra) { if (!subst) return true; - if (subst && dynamic_cast<ThisTypeSubstitution*>(subst)) - return true; + if (auto thisTypeSubst = dynamic_cast<ThisTypeSubstitution*>(subst)) + { + if (!sourceType || !thisTypeSubst->sourceType) + return true; + return sourceType->EqualsVal(thisTypeSubst->sourceType); + } return false; } @@ -1405,19 +1409,48 @@ void Type::accept(IValVisitor* visitor, void* extra) *ioDiff += diff; } + void buildMemberDictionary(ContainerDecl* decl); + DeclRefBase DeclRefBase::SubstituteImpl(Substitutions* subst, int* ioDiff) { int diff = 0; RefPtr<Substitutions> substSubst = substituteSubstitutions(substitutions, subst, &diff); - if (!diff) - return *this; + return *this; *ioDiff += diff; DeclRefBase substDeclRef; substDeclRef.decl = decl; substDeclRef.substitutions = substSubst; + + // if this is a AssocTypeDecl, try lookup the actual associated type + if (auto assocTypeDecl = substDeclRef.decl->As<AssocTypeDecl>()) + { + auto thisSubst = getThisTypeSubst(substDeclRef, false); + if (thisSubst) + { + if (auto declRefType = thisSubst->sourceType.As<DeclRefType>()) + { + if (auto aggDeclRef = declRefType->declRef.As<StructDecl>()) + { + Decl* subTypeDecl = nullptr; + buildMemberDictionary(aggDeclRef.getDecl()); + SLANG_ASSERT(aggDeclRef.getDecl()->memberDictionaryIsValid); + aggDeclRef.getDecl()->memberDictionary.TryGetValue(assocTypeDecl->getName(), subTypeDecl); + if (auto typeDefDecl = subTypeDecl->As<TypeDefDecl>()) + { + auto t = GetType(DeclRef<TypeDefDecl>(typeDefDecl, aggDeclRef.substitutions)); + auto canonicalType = t->GetCanonicalType()->AsDeclRefType(); + SLANG_ASSERT(canonicalType); + return canonicalType->declRef; + } + SLANG_ASSERT(subTypeDecl); + return DeclRefBase(subTypeDecl, aggDeclRef.substitutions); + } + } + } + } return substDeclRef; } @@ -1428,7 +1461,7 @@ void Type::accept(IValVisitor* visitor, void* extra) if (decl != declRef.decl) return false; if (!substitutions) - return !declRef.substitutions || declRef.substitutions.As<ThisTypeSubstitution>(); + return !hasGenericSubstitutions(declRef.substitutions); if (!substitutions->Equals(declRef.substitutions.Ptr())) return false; @@ -1903,6 +1936,17 @@ void Type::accept(IValVisitor* visitor, void* extra) declRef.substitutions = substToInsert; } + ThisTypeSubstitution* findThisTypeSubst(Substitutions* subst) + { + while (subst) + { + if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst)) + return thisSubst; + subst = subst->outer.Ptr(); + } + return nullptr; + } + RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry) { RefPtr<ThisTypeSubstitution> thisSubst; @@ -1958,6 +2002,18 @@ void Type::accept(IValVisitor* visitor, void* extra) } } + bool hasThisTypeSubstitutions(RefPtr<Substitutions> subst) + { + auto p = subst.Ptr(); + while (p) + { + if (dynamic_cast<ThisTypeSubstitution*>(p)) + return true; + p = p->outer.Ptr(); + } + return false; + } + bool hasGenericSubstitutions(RefPtr<Substitutions> subst) { auto p = subst.Ptr(); @@ -1987,8 +2043,22 @@ void Type::accept(IValVisitor* visitor, void* extra) if (oldSubst) oldSubst = oldSubst->SubstituteImpl(subst, ioDiff); + // if oldSubst does not have ThisTypeSubst (which means `this_type` is free variable) + // and subst has a ThisTypeSubst (which means `this_type` is bound to a type), + // then copy that ThisTypeSubst over (to bind the this_type to the specified type) RefPtr<Substitutions> newSubst = oldSubst; insertGlobalGenericSubstitutions(newSubst, subst, ioDiff); + /*if (!hasThisTypeSubstitutions(oldSubst)) + { + auto thisTypeSubst = findThisTypeSubst(subst); + if (thisTypeSubst) + { + auto cpyThisTypeSubst = new ThisTypeSubstitution(); + cpyThisTypeSubst->sourceType = thisTypeSubst->sourceType; + insertSubstAtBottom(newSubst, cpyThisTypeSubst); + *ioDiff = 1; + } + }*/ return newSubst; } } diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 00e6bd6d3..a1f4ba801 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -102,7 +102,8 @@ namespace Slang Double, }; - + class Decl; + class Val; // Forward-declare all syntax classes #define SYNTAX_CLASS(NAME, BASE, ...) class NAME; @@ -991,9 +992,9 @@ namespace Slang return declRef.Substitute(declRef.getDecl()->sub.Ptr()); } - inline RefPtr<Type> GetSup(DeclRef<GenericTypeConstraintDecl> const& declRef) + inline RefPtr<Type> GetSup(DeclRef<TypeConstraintDecl> const& declRef) { - return declRef.Substitute(declRef.getDecl()->sup.Ptr()); + return declRef.Substitute(declRef.getDecl()->getSup().type); } // Note(tfoley): These logically belong to `Type`, diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index 4ecd5a51b..f0f830cd2 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -85,6 +85,9 @@ END_SYNTAX_CLASS() ABSTRACT_SYNTAX_CLASS(SubtypeWitness, Witness) FIELD(RefPtr<Type>, sub) FIELD(RefPtr<Type>, sup) + RAW( + virtual DeclRef<Decl> getLastStepDeclRef() = 0; + ) END_SYNTAX_CLASS() // A witness that one type is a subtype of another @@ -96,6 +99,10 @@ RAW( virtual String ToString() override; virtual int GetHashCode() override; virtual RefPtr<Val> SubstituteImpl(Substitutions * subst, int * ioDiff) override; + virtual DeclRef<Decl> getLastStepDeclRef() override + { + return declRef; + } ) END_SYNTAX_CLASS() @@ -111,6 +118,10 @@ RAW( virtual String ToString() override; virtual int GetHashCode() override; virtual RefPtr<Val> SubstituteImpl(Substitutions * subst, int * ioDiff) override; + virtual DeclRef<Decl> getLastStepDeclRef() override + { + return midToSup->getLastStepDeclRef(); + } ) END_SYNTAX_CLASS() diff --git a/tests/compute/assoctype-generic-arg.slang b/tests/compute/assoctype-generic-arg.slang new file mode 100644 index 000000000..78c54ec37 --- /dev/null +++ b/tests/compute/assoctype-generic-arg.slang @@ -0,0 +1,38 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out +//TEST_INPUT:type AssocImpl + + +RWStructuredBuffer<float> outputBuffer; + +interface IBase +{ + float getVal(); +}; + +interface IAssoc +{ + associatedtype TBase : IBase; +}; + +struct BaseImpl : IBase +{ + float getVal() { return 1.0; } +}; + +struct AssocImpl : IAssoc +{ + typedef BaseImpl TBase; +}; + +__generic_param T : IAssoc; + + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + T.TBase base; + float rs = base.getVal(); + outputBuffer[tid] = rs; +}
\ No newline at end of file diff --git a/tests/compute/assoctype-generic-arg.slang.expected.txt b/tests/compute/assoctype-generic-arg.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/assoctype-generic-arg.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000
\ No newline at end of file diff --git a/tests/compute/assoctype-simple.slang b/tests/compute/assoctype-simple.slang index 0f160c9c0..b14529064 100644 --- a/tests/compute/assoctype-simple.slang +++ b/tests/compute/assoctype-simple.slang @@ -8,13 +8,13 @@ RWStructuredBuffer<float> outputBuffer; interface ISimple { associatedtype U; - U add(U v0, U v1); + U addt(U v0, U v1); } struct Simple : ISimple { typedef float U; - U add(U v0, float v1) + U addt(U v0, float v1) { return v0 + v1; } @@ -23,7 +23,7 @@ struct Simple : ISimple __generic<T:ISimple> T.U test(T simple, T.U v0, T.U v1) { - return simple.add(v0, v1); + return simple.addt(v0, v1); } [numthreads(4, 1, 1)] diff --git a/tools/render-test/test.txt b/tools/render-test/test.txt deleted file mode 100644 index deb1c3630..000000000 --- a/tools/render-test/test.txt +++ /dev/null @@ -1 +0,0 @@ -3F800000 |
