summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/core/slang-list.h4
-rw-r--r--source/slang/slang-check-decl.cpp20
-rw-r--r--source/slang/slang-syntax.cpp12
-rw-r--r--source/slang/slang-syntax.h1
-rw-r--r--tests/language-feature/interfaces/overloaded-associatedtype.slang43
5 files changed, 78 insertions, 2 deletions
diff --git a/source/core/slang-list.h b/source/core/slang-list.h
index d27afd415..7c96e3844 100644
--- a/source/core/slang-list.h
+++ b/source/core/slang-list.h
@@ -537,7 +537,7 @@ public:
}
}
- inline void swapElements(T* vals, Index index1, Index index2)
+ inline static void swapElements(T* vals, Index index1, Index index2)
{
if (index1 != index2)
{
@@ -547,6 +547,8 @@ public:
}
}
+ inline void swapElements(Index index1, Index index2) { swapElements(m_buffer, index1, index2); }
+
template<typename T2, typename Comparer>
Index binarySearch(const T2& obj, Comparer comparer) const
{
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index c9497ce54..3fea267ee 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6541,7 +6541,25 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement(
}
}
}
-
+ if (lookupResult.isOverloaded())
+ {
+ // If we found multiple members with the same name,
+ // we want to move the declarations in the same parent as inheritanceDecl
+ // to the front of the list, so that we always consider them first instead of
+ // the members declared in other extension decls.
+ //
+ Index front = 0;
+ auto parentOfInheritanceDecl = getParentAggTypeDeclBase(inheritanceDecl);
+ for (Index i = 0; i < lookupResult.items.getCount(); i++)
+ {
+ if (getParentAggTypeDeclBase(lookupResult.items[i].declRef.getDecl()) ==
+ parentOfInheritanceDecl)
+ {
+ lookupResult.items.swapElements(i, front);
+ front++;
+ }
+ }
+ }
// Iterate over the members and look for one that matches
// the expected signature for the requirement.
for (auto member : lookupResult)
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 1d3763299..5dc6ca695 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1051,6 +1051,18 @@ Decl* getParentAggTypeDecl(Decl* decl)
return nullptr;
}
+Decl* getParentAggTypeDeclBase(Decl* decl)
+{
+ decl = decl->parentDecl;
+ while (decl)
+ {
+ if (as<AggTypeDeclBase>(decl))
+ return decl;
+ decl = decl->parentDecl;
+ }
+ return nullptr;
+}
+
Decl* getParentFunc(Decl* decl)
{
while (decl)
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 5e31b5447..accc490f2 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -370,6 +370,7 @@ Module* getModule(Decl* decl);
/// Get the parent decl, skipping any generic decls in between.
Decl* getParentDecl(Decl* decl);
Decl* getParentAggTypeDecl(Decl* decl);
+Decl* getParentAggTypeDeclBase(Decl* decl);
Decl* getParentFunc(Decl* decl);
} // namespace Slang
diff --git a/tests/language-feature/interfaces/overloaded-associatedtype.slang b/tests/language-feature/interfaces/overloaded-associatedtype.slang
new file mode 100644
index 000000000..4630c76a6
--- /dev/null
+++ b/tests/language-feature/interfaces/overloaded-associatedtype.slang
@@ -0,0 +1,43 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type
+
+interface IFoo<T> {
+ associatedtype Output : IBar;
+ func foo(other: T) -> Output;
+}
+
+interface IBar { int getId(); }
+
+struct Ant:IBar { int getId() { return 0; } };
+struct Bat:IBar { int getId() { return 1; } };
+struct Cat:IBar { int getId() { return 2; } };
+struct Dog:IBar { int getId() { return 3; } };
+struct Ewe:IBar { int getId() { return 4; } };
+struct Fox:IBar { int getId() { return 5; } };
+struct Gnu:IBar { int getId() { return 6; } };
+
+extension Ant: IFoo<Bat> {
+ typedef Cat Output;
+ func foo(other: Bat) -> Cat { return Cat(); }
+}
+extension Ant: IFoo<Dog> {
+ typedef Ewe Output;
+ func foo(other: Dog) -> Ewe { return Ewe(); }
+}
+extension Ant: IFoo<Fox> {
+ typedef Gnu Output;
+ func foo(other: Fox) -> Gnu { return Gnu(); }
+}
+
+int test<T:IFoo<Fox>>(T v) {
+ return v.foo(Fox()).getId();
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=output
+RWStructuredBuffer<int> output;
+
+[numthreads(1,1,1)]
+void computeMain() {
+ Ant a;
+ // CHECK: 6
+ output[0] = test(a);
+} \ No newline at end of file