summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2018-01-14 16:22:11 -0500
committerYong He <yonghe@outlook.com>2018-01-14 16:22:11 -0500
commitd33e6b7475a87d5a62101afc81813e9c9e458a70 (patch)
tree7ed6bc875ee0872f35218e7c76d18bf5bcad02ec
parentd4dab2cd3a409411c2d7caed01fc02a0fd3e8450 (diff)
allow extension of a concrete type to implement additional interface
Also support the scenario that the extension declares conformance to interface I, and a method M in I is already supported by the base implementation.
-rw-r--r--source/slang/check.cpp39
-rw-r--r--source/slang/lower-to-ir.cpp25
-rw-r--r--source/slang/syntax.h52
-rw-r--r--tests/compute/extension-multi-interface.slang49
-rw-r--r--tests/compute/extension-multi-interface.slang.expected.txt4
-rw-r--r--tests/compute/multi-interface.slang45
-rw-r--r--tests/compute/multi-interface.slang.expected.txt4
7 files changed, 189 insertions, 29 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 6c484b493..6b8331060 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -1794,7 +1794,7 @@ namespace Slang
// `requiredMemberDeclRef` is a required member of
// the interface.
RefPtr<Decl> findWitnessForInterfaceRequirement(
- DeclRef<AggTypeDecl> typeDeclRef,
+ DeclRef<AggTypeDeclBase> typeDeclRef,
InheritanceDecl* inheritanceDecl,
DeclRef<InterfaceDecl> interfaceDeclRef,
DeclRef<Decl> requiredMemberDeclRef,
@@ -1833,11 +1833,9 @@ namespace Slang
// Make sure that by-name lookup is possible.
buildMemberDictionary(typeDeclRef.getDecl());
-
- Decl* firstMemberOfName = nullptr;
- typeDeclRef.getDecl()->memberDictionary.TryGetValue(name, firstMemberOfName);
-
- if (!firstMemberOfName)
+ auto lookupResult = lookUpLocal(getSession(), this, name, typeDeclRef);
+
+ if (!lookupResult.isValid())
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, typeDeclRef, requiredMemberDeclRef);
return nullptr;
@@ -1845,10 +1843,10 @@ namespace Slang
// Iterate over the members and look for one that matches
// the expected signature for the requirement.
- for (auto memberDecl = firstMemberOfName; memberDecl; memberDecl = memberDecl->nextInContainerWithSameName)
+ for (auto member : lookupResult)
{
- if (doesMemberSatisfyRequirement(DeclRef<Decl>(memberDecl, typeDeclRef.substitutions), requiredMemberDeclRef, requirementWitness))
- return memberDecl;
+ if (doesMemberSatisfyRequirement(member.declRef, requiredMemberDeclRef, requirementWitness))
+ return member.declRef.getDecl();
}
// No suitable member found, although there were candidates.
@@ -1867,7 +1865,7 @@ namespace Slang
// (via the given `inheritanceDecl`) actually provides
// members to satisfy all the requirements in the interface.
bool checkInterfaceConformance(
- DeclRef<AggTypeDecl> typeDeclRef,
+ DeclRef<AggTypeDeclBase> typeDeclRef,
InheritanceDecl* inheritanceDecl,
DeclRef<InterfaceDecl> interfaceDeclRef)
{
@@ -1925,7 +1923,7 @@ namespace Slang
}
bool checkConformanceToType(
- DeclRef<AggTypeDecl> typeDeclRef,
+ DeclRef<AggTypeDeclBase> typeDeclRef,
InheritanceDecl* inheritanceDecl,
Type* baseType)
{
@@ -1953,7 +1951,7 @@ namespace Slang
// `inheritanceDecl` actually does what it needs to
// for that inheritance to be valid.
bool checkConformance(
- DeclRef<AggTypeDecl> typeDecl,
+ DeclRef<AggTypeDeclBase> typeDecl,
InheritanceDecl* inheritanceDecl)
{
// Look at the type being inherited from, and validate
@@ -1963,10 +1961,10 @@ namespace Slang
}
bool checkConformance(
- AggTypeDecl* typeDecl,
+ AggTypeDeclBase* typeDecl,
InheritanceDecl* inheritanceDecl)
{
- return checkConformance(DeclRef<AggTypeDecl>(typeDecl, SubstitutionSet()), inheritanceDecl);
+ return checkConformance(DeclRef<AggTypeDeclBase>(typeDecl, SubstitutionSet()), inheritanceDecl);
}
void visitAggTypeDecl(AggTypeDecl* decl)
@@ -3479,10 +3477,11 @@ namespace Slang
// TODO: need to check that the target type names a declaration...
+ DeclRef<AggTypeDecl> aggTypeDeclRef;
if (auto targetDeclRefType = decl->targetType->As<DeclRefType>())
{
// Attach our extension to that type as a candidate...
- if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
+ if (aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>())
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();
decl->nextCandidateExtension = aggTypeDecl->candidateExtensions;
@@ -3516,6 +3515,14 @@ namespace Slang
EnsureDecl(m);
}
+ if (aggTypeDeclRef)
+ {
+ for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
+ {
+ checkConformance(aggTypeDeclRef.getDecl(), inheritanceDecl);
+ }
+ }
+
decl->SetCheckState(DeclCheckState::Checked);
}
@@ -3802,7 +3809,7 @@ namespace Slang
if( auto aggTypeDeclRef = declRef.As<AggTypeDecl>() )
{
- for( auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclRef))
+ for( auto inheritanceDeclRef : getMembersOfTypeWithExt<InheritanceDecl>(aggTypeDeclRef))
{
EnsureDecl(inheritanceDeclRef.getDecl());
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 61ca53278..498783f4b 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -2788,19 +2788,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// TODO: if this inheritance declaration is under an extension,
// then we should construct the type that is being extended,
// and not a reference to the extension itself.
- auto parentDecl = inheritanceDecl->ParentDecl;
- RefPtr<Type> type = DeclRefType::Create(
- context->getSession(),
- makeDeclRef(parentDecl));
+ auto parentDecl = inheritanceDecl->ParentDecl;
+ RefPtr<Type> type;
+ if (auto extParentDecl = dynamic_cast<ExtensionDecl*>(parentDecl))
+ {
+ type = extParentDecl->targetType.type;
+ if (auto declRefType = type.As<DeclRefType>())
+ {
+ if (auto aggTypeDecl = declRefType->declRef.As<AggTypeDecl>())
+ parentDecl = aggTypeDecl.getDecl();
+ }
+ }
+ else
+ {
+ type = DeclRefType::Create(
+ context->getSession(),
+ makeDeclRef(parentDecl));
+ }
// What is the super-type that we have declared we inherit from?
RefPtr<Type> superType = inheritanceDecl->base.type;
// Construct the mangled name for the witness table, which depends
// on the type that is conforming, and the type that it conforms to.
- String mangledName = getMangledNameForConformanceWitness(
- makeDeclRef(parentDecl),
- superType);
+ String mangledName = getMangledNameForConformanceWitness(type, superType);
// Build an IR level witness table, which will represent the
// conformance of the type to its super-type.
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 375eb5f1c..ab26b1f6d 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -978,6 +978,30 @@ namespace Slang
{
return items.Count() > 1 ? items[0].declRef.GetName() : item.declRef.GetName();
}
+ LookupResultItem* begin()
+ {
+ if (isValid())
+ {
+ if (isOverloaded())
+ return items.begin();
+ else
+ return &item;
+ }
+ else
+ return nullptr;
+ }
+ LookupResultItem* end()
+ {
+ if (isValid())
+ {
+ if (isOverloaded())
+ return items.end();
+ else
+ return &item + 1;
+ }
+ else
+ return nullptr;
+ }
};
struct SemanticsVisitor;
@@ -1085,6 +1109,27 @@ namespace Slang
return FilteredMemberRefList<T>(declRef.getDecl()->Members, declRef.substitutions);
}
+ inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
+ {
+ return declRef.getDecl()->candidateExtensions;
+ }
+
+ template<typename T>
+ inline FilteredMemberRefList<T> getMembersOfTypeWithExt(DeclRef<ContainerDecl> const& declRef)
+ {
+ auto rs = getMembersOfType<T>(declRef);
+ if (auto aggDeclRef = declRef.As<AggTypeDecl>())
+ {
+ for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension)
+ {
+ auto extMembers = getMembersOfType<T>(DeclRef<ContainerDecl>(ext, declRef.substitutions));
+ const_cast<List<RefPtr<Decl>>&>(rs.decls).AddRange(extMembers.decls);
+ }
+ }
+ return rs;
+ }
+
+
inline RefPtr<Type> GetType(DeclRef<VarDeclBase> const& declRef)
{
return declRef.Substitute(declRef.getDecl()->type.Ptr());
@@ -1099,12 +1144,7 @@ namespace Slang
{
return declRef.Substitute(declRef.getDecl()->targetType.Ptr());
}
-
- inline ExtensionDecl* GetCandidateExtensions(DeclRef<AggTypeDecl> const& declRef)
- {
- return declRef.getDecl()->candidateExtensions;
- }
-
+
inline FilteredMemberRefList<StructField> GetFields(DeclRef<StructDecl> const& declRef)
{
return getMembersOfType<StructField>(declRef);
diff --git a/tests/compute/extension-multi-interface.slang b/tests/compute/extension-multi-interface.slang
new file mode 100644
index 000000000..6cc88f87c
--- /dev/null
+++ b/tests/compute/extension-multi-interface.slang
@@ -0,0 +1,49 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IAdd
+{
+ float addf(float u, float v);
+}
+
+interface ISub
+{
+ float subf(float u, float v);
+}
+
+interface IAddAndSub
+{
+ float addf(float u, float v);
+ float subf(float u, float v);
+}
+
+struct Simple : IAdd
+{
+ float addf(float u, float v)
+ {
+ return u+v;
+ }
+};
+
+__extension Simple : ISub, IAddAndSub
+{
+ float subf(float u, float v)
+ {
+ return u-v;
+ }
+};
+
+float testAddSub<T:IAddAndSub>(T t)
+{
+ return t.subf(t.addf(1.0, 1.0), 1.0);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Simple s;
+ float outVal = testAddSub(s);
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/extension-multi-interface.slang.expected.txt b/tests/compute/extension-multi-interface.slang.expected.txt
new file mode 100644
index 000000000..e143b7f20
--- /dev/null
+++ b/tests/compute/extension-multi-interface.slang.expected.txt
@@ -0,0 +1,4 @@
+3F800000
+3F800000
+3F800000
+3F800000 \ No newline at end of file
diff --git a/tests/compute/multi-interface.slang b/tests/compute/multi-interface.slang
new file mode 100644
index 000000000..1f9775211
--- /dev/null
+++ b/tests/compute/multi-interface.slang
@@ -0,0 +1,45 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IAdd
+{
+ float addf(float u, float v);
+}
+
+interface ISub
+{
+ float subf(float u, float v);
+}
+
+interface IAddAndSub
+{
+ float addf(float u, float v);
+ float subf(float u, float v);
+}
+
+struct Simple : IAdd, ISub, IAddAndSub
+{
+ float addf(float u, float v)
+ {
+ return u+v;
+ }
+ float subf(float u, float v)
+ {
+ return u-v;
+ }
+};
+
+float testAddSub<T:IAddAndSub>(T t)
+{
+ return t.subf(t.addf(1.0, 1.0), 1.0);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Simple s;
+ float outVal = testAddSub(s);
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/multi-interface.slang.expected.txt b/tests/compute/multi-interface.slang.expected.txt
new file mode 100644
index 000000000..e143b7f20
--- /dev/null
+++ b/tests/compute/multi-interface.slang.expected.txt
@@ -0,0 +1,4 @@
+3F800000
+3F800000
+3F800000
+3F800000 \ No newline at end of file