diff options
| author | Yong He <yonghe@outlook.com> | 2018-01-14 16:22:11 -0500 |
|---|---|---|
| committer | Yong He <yonghe@outlook.com> | 2018-01-14 16:22:11 -0500 |
| commit | d33e6b7475a87d5a62101afc81813e9c9e458a70 (patch) | |
| tree | 7ed6bc875ee0872f35218e7c76d18bf5bcad02ec | |
| parent | d4dab2cd3a409411c2d7caed01fc02a0fd3e8450 (diff) | |
allow extension of a concrete type to implement additional interface
Also support the scenario that the extension declares conformance to interface I, and a method M in I is already supported by the base implementation.
| -rw-r--r-- | source/slang/check.cpp | 39 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 25 | ||||
| -rw-r--r-- | source/slang/syntax.h | 52 | ||||
| -rw-r--r-- | tests/compute/extension-multi-interface.slang | 49 | ||||
| -rw-r--r-- | tests/compute/extension-multi-interface.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/multi-interface.slang | 45 | ||||
| -rw-r--r-- | tests/compute/multi-interface.slang.expected.txt | 4 |
7 files changed, 189 insertions, 29 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 6c484b493..6b8331060 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -1794,7 +1794,7 @@ namespace Slang // `requiredMemberDeclRef` is a required member of // the interface. RefPtr<Decl> findWitnessForInterfaceRequirement( - DeclRef<AggTypeDecl> typeDeclRef, + DeclRef<AggTypeDeclBase> typeDeclRef, InheritanceDecl* inheritanceDecl, DeclRef<InterfaceDecl> interfaceDeclRef, DeclRef<Decl> requiredMemberDeclRef, @@ -1833,11 +1833,9 @@ namespace Slang // Make sure that by-name lookup is possible. buildMemberDictionary(typeDeclRef.getDecl()); - - Decl* firstMemberOfName = nullptr; - typeDeclRef.getDecl()->memberDictionary.TryGetValue(name, firstMemberOfName); - - if (!firstMemberOfName) + auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef); + + if (!lookupResult.isValid()) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef); return nullptr; @@ -1845,10 +1843,10 @@ namespace Slang // Iterate over the members and look for one that matches // the expected signature for the requirement. - for (auto memberDecl = firstMemberOfName; memberDecl; memberDecl = memberDecl->nextInContainerWithSameName) + for (auto member : lookupResult) { - if (doesMemberSatisfyRequirement(DeclRef<Decl>(memberDecl, typeDeclRef.substitutions), requiredMemberDeclRef, requirementWitness)) - return memberDecl; + if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, requirementWitness)) + return member.declRef.getDecl(); } // No suitable member found, although there were candidates. @@ -1867,7 +1865,7 @@ namespace Slang // (via the given `inheritanceDecl`) actually provides // members to satisfy all the requirements in the interface. bool checkInterfaceConformance( - DeclRef<AggTypeDecl> typeDeclRef, + DeclRef<AggTypeDeclBase> typeDeclRef, InheritanceDecl* inheritanceDecl, DeclRef<InterfaceDecl> interfaceDeclRef) { @@ -1925,7 +1923,7 @@ namespace Slang } bool checkConformanceToType( - DeclRef<AggTypeDecl> typeDeclRef, + DeclRef<AggTypeDeclBase> typeDeclRef, InheritanceDecl* inheritanceDecl, Type* baseType) { @@ -1953,7 +1951,7 @@ namespace Slang // `inheritanceDecl` actually does what it needs to // for that inheritance to be valid. bool checkConformance( - DeclRef<AggTypeDecl> typeDecl, + DeclRef<AggTypeDeclBase> typeDecl, InheritanceDecl* inheritanceDecl) { // Look at the type being inherited from, and validate @@ -1963,10 +1961,10 @@ namespace Slang } bool checkConformance( - AggTypeDecl* typeDecl, + AggTypeDeclBase* typeDecl, InheritanceDecl* inheritanceDecl) { - return checkConformance(DeclRef<AggTypeDecl>(typeDecl, SubstitutionSet()), inheritanceDecl); + return checkConformance(DeclRef<AggTypeDeclBase>(typeDecl, SubstitutionSet()), inheritanceDecl); } void visitAggTypeDecl(AggTypeDecl* decl) @@ -3479,10 +3477,11 @@ namespace Slang // TODO: need to check that the target type names a declaration... + DeclRef<AggTypeDecl> aggTypeDeclRef; if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) { // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>()) + if (aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; @@ -3516,6 +3515,14 @@ namespace Slang EnsureDecl(m); } + if (aggTypeDeclRef) + { + for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) + { + checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl); + } + } + decl->SetCheckState(DeclCheckState::Checked); } @@ -3802,7 +3809,7 @@ namespace Slang if( auto aggTypeDeclRef = declRef.As<AggTypeDecl>() ) { - for( auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclRef)) + for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef)) { EnsureDecl(inheritanceDeclRef.getDecl()); diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 61ca53278..498783f4b 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -2788,19 +2788,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // TODO: if this inheritance declaration is under an extension, // then we should construct the type that is being extended, // and not a reference to the extension itself. - auto parentDecl = inheritanceDecl->ParentDecl; - RefPtr<Type> type = DeclRefType::Create( - context->getSession(), - makeDeclRef(parentDecl)); + auto parentDecl = inheritanceDecl->ParentDecl; + RefPtr<Type> type; + if (auto extParentDecl = dynamic_cast<ExtensionDecl*>(parentDecl)) + { + type = extParentDecl->targetType.type; + if (auto declRefType = type.As<DeclRefType>()) + { + if (auto aggTypeDecl = declRefType->declRef.As<AggTypeDecl>()) + parentDecl = aggTypeDecl.getDecl(); + } + } + else + { + type = DeclRefType::Create( + context->getSession(), + makeDeclRef(parentDecl)); + } // What is the super-type that we have declared we inherit from? RefPtr<Type> superType = inheritanceDecl->base.type; // Construct the mangled name for the witness table, which depends // on the type that is conforming, and the type that it conforms to. - String mangledName = getMangledNameForConformanceWitness( - makeDeclRef(parentDecl), - superType); + String mangledName = getMangledNameForConformanceWitness(type, superType); // Build an IR level witness table, which will represent the // conformance of the type to its super-type. diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 375eb5f1c..ab26b1f6d 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -978,6 +978,30 @@ namespace Slang { return items.Count() > 1 ? items[0].declRef.GetName() : item.declRef.GetName(); } + LookupResultItem* begin() + { + if (isValid()) + { + if (isOverloaded()) + return items.begin(); + else + return &item; + } + else + return nullptr; + } + LookupResultItem* end() + { + if (isValid()) + { + if (isOverloaded()) + return items.end(); + else + return &item + 1; + } + else + return nullptr; + } }; struct SemanticsVisitor; @@ -1085,6 +1109,27 @@ namespace Slang return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions); } + inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef) + { + return declRef.getDecl()->candidateExtensions; + } + + template<typename T> + inline FilteredMemberRefList<T> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef) + { + auto rs = getMembersOfType<T>(declRef); + if (auto aggDeclRef = declRef.As<AggTypeDecl>()) + { + for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) + { + auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions)); + const_cast<List<RefPtr<Decl>>&>(rs.decls).AddRange(extMembers.decls); + } + } + return rs; + } + + inline RefPtr<Type> GetType(DeclRef<VarDeclBase> const& declRef) { return declRef.Substitute(declRef.getDecl()->type.Ptr()); @@ -1099,12 +1144,7 @@ namespace Slang { return declRef.Substitute(declRef.getDecl()->targetType.Ptr()); } - - inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef) - { - return declRef.getDecl()->candidateExtensions; - } - + inline FilteredMemberRefList<StructField> GetFields(DeclRef<StructDecl> const& declRef) { return getMembersOfType<StructField>(declRef); diff --git a/tests/compute/extension-multi-interface.slang b/tests/compute/extension-multi-interface.slang new file mode 100644 index 000000000..6cc88f87c --- /dev/null +++ b/tests/compute/extension-multi-interface.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +interface IAdd +{ + float addf(float u, float v); +} + +interface ISub +{ + float subf(float u, float v); +} + +interface IAddAndSub +{ + float addf(float u, float v); + float subf(float u, float v); +} + +struct Simple : IAdd +{ + float addf(float u, float v) + { + return u+v; + } +}; + +__extension Simple : ISub, IAddAndSub +{ + float subf(float u, float v) + { + return u-v; + } +}; + +float testAddSub<T:IAddAndSub>(T t) +{ + return t.subf(t.addf(1.0, 1.0), 1.0); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Simple s; + float outVal = testAddSub(s); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/extension-multi-interface.slang.expected.txt b/tests/compute/extension-multi-interface.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/extension-multi-interface.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000
\ No newline at end of file diff --git a/tests/compute/multi-interface.slang b/tests/compute/multi-interface.slang new file mode 100644 index 000000000..1f9775211 --- /dev/null +++ b/tests/compute/multi-interface.slang @@ -0,0 +1,45 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +interface IAdd +{ + float addf(float u, float v); +} + +interface ISub +{ + float subf(float u, float v); +} + +interface IAddAndSub +{ + float addf(float u, float v); + float subf(float u, float v); +} + +struct Simple : IAdd, ISub, IAddAndSub +{ + float addf(float u, float v) + { + return u+v; + } + float subf(float u, float v) + { + return u-v; + } +}; + +float testAddSub<T:IAddAndSub>(T t) +{ + return t.subf(t.addf(1.0, 1.0), 1.0); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Simple s; + float outVal = testAddSub(s); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/multi-interface.slang.expected.txt b/tests/compute/multi-interface.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/multi-interface.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000
\ No newline at end of file |
