summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-overload.cpp139
-rw-r--r--tests/bugs/overload-ambiguous-1.slang65
-rw-r--r--tests/bugs/overload-ambiguous-2.slang67
-rw-r--r--tests/bugs/overload-ambiguous.slang19
-rw-r--r--tests/diagnostics/overload-ambiguous.slang45
5 files changed, 294 insertions, 41 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 70eabb4f7..b2173cd7b 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1200,28 +1200,6 @@ namespace Slang
return parent;
}
- void countDistanceToGloablScope(DeclRef<Slang::Decl> const& leftDecl,
- DeclRef<Slang::Decl> const& rightDecl,
- int& leftDistance, int& rightDistance)
- {
- leftDistance = 0;
- rightDistance = 0;
-
- DeclRef<Decl> decl = leftDecl;
- while(decl)
- {
- leftDistance++;
- decl = decl.getParent();
- }
-
- decl = rightDecl;
- while(decl)
- {
- rightDistance++;
- decl = decl.getParent();
- }
- }
-
// Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal.
//
int SemanticsVisitor::CompareLookupResultItems(
@@ -1347,23 +1325,6 @@ namespace Slang
}
}
- // We need to consider the distance of the declarations to the global scope to resolve this case:
- // float f(float x);
- // struct S
- // {
- // float f(float x);
- // float g(float y) { return f(y); } // will call S::f() instead of ::f()
- // }
- // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope,
- // because this function will only choose the valid candidates. So if there is situation like this:
- // void main() { S s; s.f(1.0);} or
- // struct T { float g(y) { f(y); } }, there won't be ambiguity.
- // So we just need to count which declaration is farther from the global scope and favor the farther one.
- int leftDistance = 0;
- int rightDistance = 0;
- countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance);
- if (leftDistance != rightDistance)
- return leftDistance > rightDistance ? -1 : 1;
// TODO: We should generalize above rules such that in a tie a declaration
// A::m is better than B::m when all other factors are equal and
@@ -1479,6 +1440,70 @@ namespace Slang
return 0;
}
+ int getScopeRank(DeclRef<Decl> const& left,
+ DeclRef<Decl> const& right, Slang::Scope* referenceSiteScope)
+ {
+ if (!referenceSiteScope)
+ return 0;
+
+ DeclRef<Decl> prefixDecl = referenceSiteScope->containerDecl;
+
+ // Hold the path from reference site to the root
+ // key: Decl node, value: distance from reference site
+ Dictionary<Decl*, uint32_t> refPath;
+ for (auto node = prefixDecl; node != nullptr; node = node.getParent())
+ {
+ Decl* key = node.getDecl();
+ uint32_t value = (uint32_t)refPath.getCount();
+ refPath.add(key, value);
+ }
+
+ // find the common prefix decl of reference site and left
+ int leftDistance = 0;
+ int rightDistance = 0;
+ auto distanceToCommonPrefix = [](DeclRef<Decl> const& candidate, Dictionary<Decl*, uint32_t> refPath) -> int
+ {
+ uint32_t distanceToReferenceSite = 0;
+ uint32_t distanceToCandidate = 0;
+
+ // Sanity check
+ if (candidate.getDecl() == nullptr)
+ return -1;
+
+ // search from candidate to root, once we found the first node in the reference path, that is the first
+ // common prefix, and we can stop searching.
+ for (auto node = candidate; node != nullptr; node = node.getParent())
+ {
+ Decl* key = node.getDecl();
+ if (refPath.tryGetValue(key, distanceToReferenceSite))
+ {
+ break;
+ }
+ distanceToCandidate++;
+ }
+
+ // If we don't find the common prefix, there must be something wrong, return the max value.
+ if (distanceToReferenceSite == 0)
+ return -1;
+
+ return distanceToReferenceSite + distanceToCandidate;
+ };
+
+ leftDistance = distanceToCommonPrefix(left, refPath);
+ rightDistance = distanceToCommonPrefix(right, refPath);
+
+ if (leftDistance == rightDistance)
+ return 0;
+
+ if (leftDistance == -1)
+ return 1;
+
+ if (rightDistance == -1)
+ return -1;
+
+ return leftDistance < rightDistance ? -1 : 1;
+ }
+
int SemanticsVisitor::CompareOverloadCandidates(
OverloadCandidate* left,
OverloadCandidate* right)
@@ -1558,6 +1583,42 @@ namespace Slang
if (externExportDiff)
return externExportDiff;
+ // We need to consider the distance of the declarations to the global scope to resolve this case:
+ // float f(float x);
+ // struct S
+ // {
+ // float f(float x);
+ // float g(float y) { return f(y); } // will call S::f() instead of ::f()
+ // }
+ // we will count the distance from the reference site to the declaration in the scope tree.
+
+ // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated.
+ // It will go through multiple passes of candidates compare.
+ // In the first pass, it will lookup all the generic candidates that matches the generic parameter only,
+ // e.g., the following generic functions are totally different, but they will be selected as candidates
+ // because the function name and the generic parameters are the same:
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b);
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c);
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c, Z1 d);
+ //
+ // So in this case, we should not consider the scope rank and overload rank at all, because there is only
+ // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the
+ // candidates, so it could select the wrong candidate.
+ //
+ // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid
+ // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied.
+ if (left->flavor == OverloadCandidate::Flavor::Generic ||
+ left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric ||
+ right->flavor == OverloadCandidate::Flavor::Generic ||
+ right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric)
+ {
+ return 0;
+ }
+
+ auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope);
+ if (scopeRank)
+ return scopeRank;
+
// If we reach here, we will attempt to use overload rank to break the ties.
auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef);
if (overloadRankDiff)
diff --git a/tests/bugs/overload-ambiguous-1.slang b/tests/bugs/overload-ambiguous-1.slang
new file mode 100644
index 000000000..9f9c6e5bc
--- /dev/null
+++ b/tests/bugs/overload-ambiguous-1.slang
@@ -0,0 +1,65 @@
+// https://github.com/shader-slang/slang/issues/4476
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+namespace A1
+{
+ uint func()
+ {
+ return 1u;
+ }
+
+ namespace A2
+ {
+ uint func()
+ {
+ return 2u;
+ }
+
+ namespace A3
+ {
+ uint func()
+ {
+ return 3u;
+ }
+
+ uint test2()
+ {
+ return func(); // choose A3::func()
+ }
+ }
+
+ namespace A4
+ {
+ uint test()
+ {
+ return func(); // choose A2::func()
+ }
+ }
+ }
+}
+
+[numthreads(1, 1, 1)]
+[shader("compute")]
+void computeMain(uint3 threadID: SV_DispatchThreadID)
+{
+ using namespace A1;
+ using namespace A1::A2;
+ using namespace A1::A2::A3;
+ using namespace A1::A2::A4;
+ outputBuffer[0] = test();
+ // BUF: 2
+
+ outputBuffer[1] = func(); // choose the A1::func()
+ // BUF-NEXT: 1
+
+ outputBuffer[2] = test2();
+ // BUF-NEXT: 3
+}
diff --git a/tests/bugs/overload-ambiguous-2.slang b/tests/bugs/overload-ambiguous-2.slang
new file mode 100644
index 000000000..46af9f091
--- /dev/null
+++ b/tests/bugs/overload-ambiguous-2.slang
@@ -0,0 +1,67 @@
+// https://github.com/shader-slang/slang/issues/4476
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+namespace A
+{
+ struct Struct1<let SIZE : uint>
+ {
+ uint data;
+ };
+
+ Struct1<Z1> myFunc<let Z0 : uint, let Z1 : uint>(Struct1<Z0> inputS1)
+ {
+ Struct1<Z1> s1;
+ s1.data = inputS1.data + 2U;
+ return s1;
+ }
+};
+
+
+A::Struct1<Z1> myFunc<let Z0 : uint, let Z1 : uint>(A::Struct1<Z0> inputS1)
+{
+ A::Struct1<Z1> s1;
+ s1.data = inputS1.data + 5U;
+ return s1;
+}
+
+namespace A
+{
+ struct Struct2<let SIZE : uint>
+ {
+ Struct1<SIZE> s1;
+ }
+
+ Struct2<Z1> myFunc<let Z0 : uint, let Z1 : uint>(Struct2<Z0> inputS2)
+ {
+ Struct2<Z1> s2;
+ // We want to cover a corner case in our compiler where:
+ // when looking up "myFunc", the compiler should find
+ // Struct1<Z1> A::myFunc<let Z0 : uint, let Z1 : uint>(Struct1<Z0> inputS1)
+ // and it won't be ambiguous with the global "myFunc".
+ s2.s1 = myFunc<Z0, Z1>(inputS2.s1);
+ return s2;
+ }
+};
+
+[numthreads(1, 1, 1)]
+[shader("compute")]
+void computeMain(uint3 threadID: SV_DispatchThreadID)
+{
+ using namespace A;
+
+ Struct2<10> input = {threadID.x};
+
+ Struct2<20> output;
+ output = myFunc<10, 20>(input);
+ outputBuffer[0] = output.s1.data;
+
+ // BUF: 2
+}
diff --git a/tests/bugs/overload-ambiguous.slang b/tests/bugs/overload-ambiguous.slang
index 1b74cb68c..d764f72e4 100644
--- a/tests/bugs/overload-ambiguous.slang
+++ b/tests/bugs/overload-ambiguous.slang
@@ -6,7 +6,7 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -dx12 -shaderobj
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj
-//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<uint> outputBuffer;
@@ -34,7 +34,18 @@ struct DataObtainer
}
}
-RWStructuredBuffer<uint> output;
+uint myFunc(uint a)
+{
+ return a + 1u;
+}
+
+__generic<T: __BuiltinIntegerType>
+uint myFunc(T a)
+{
+ uint b = __intCast<uint, T>(a);
+ return b + 2u;
+}
+
[numthreads(1, 1, 1)]
[shader("compute")]
@@ -43,6 +54,10 @@ void computeMain(uint3 threadID: SV_DispatchThreadID)
DataObtainer obtainer = {2u};
outputBuffer[0] = obtainer.getValue();
outputBuffer[1] = obtainer.getValue2();
+
+ uint a = 1u;
+ outputBuffer[2] = myFunc(a); // will call myFunc(uint) which more specialized
// BUF: 2
// BUF-NEXT: 1
+ // BUF-NEXT: 2
}
diff --git a/tests/diagnostics/overload-ambiguous.slang b/tests/diagnostics/overload-ambiguous.slang
new file mode 100644
index 000000000..0c8f7bd21
--- /dev/null
+++ b/tests/diagnostics/overload-ambiguous.slang
@@ -0,0 +1,45 @@
+// https://github.com/shader-slang/slang/issues/4476
+
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+RWStructuredBuffer<uint> outputBuffer;
+
+namespace A1
+{
+ uint func()
+ {
+ return 1u;
+ }
+
+ namespace A2
+ {
+ uint func()
+ {
+ return 2u;
+ }
+ }
+}
+namespace B1
+{
+ uint func()
+ {
+ return 4u;
+ }
+}
+
+[numthreads(1, 1, 1)]
+[shader("compute")]
+void computeMain(uint3 threadID: SV_DispatchThreadID)
+{
+ using namespace A1;
+ using namespace A1::A2;
+ using namespace B1;
+ using namespace C1;
+
+ // Only A1::func() and B1::func() will cause ambiguity because the distance from
+ // the reference site to those two functions declaration are the same.
+ outputBuffer[0] = func();
+ // CHECK-NOT: {{.*}}A2::func() -> uint
+ // CHECK: ambiguous call to 'func' with arguments of type ()
+ // CHECK: candidate: func B1::func() -> uint
+ // CHECK: candidate: func A1::func() -> uint
+}