summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-overload.cpp139
1 files changed, 100 insertions, 39 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 70eabb4f7..b2173cd7b 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1200,28 +1200,6 @@ namespace Slang
return parent;
}
- void countDistanceToGloablScope(DeclRef<Slang::Decl> const& leftDecl,
- DeclRef<Slang::Decl> const& rightDecl,
- int& leftDistance, int& rightDistance)
- {
- leftDistance = 0;
- rightDistance = 0;
-
- DeclRef<Decl> decl = leftDecl;
- while(decl)
- {
- leftDistance++;
- decl = decl.getParent();
- }
-
- decl = rightDecl;
- while(decl)
- {
- rightDistance++;
- decl = decl.getParent();
- }
- }
-
// Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal.
//
int SemanticsVisitor::CompareLookupResultItems(
@@ -1347,23 +1325,6 @@ namespace Slang
}
}
- // We need to consider the distance of the declarations to the global scope to resolve this case:
- // float f(float x);
- // struct S
- // {
- // float f(float x);
- // float g(float y) { return f(y); } // will call S::f() instead of ::f()
- // }
- // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope,
- // because this function will only choose the valid candidates. So if there is situation like this:
- // void main() { S s; s.f(1.0);} or
- // struct T { float g(y) { f(y); } }, there won't be ambiguity.
- // So we just need to count which declaration is farther from the global scope and favor the farther one.
- int leftDistance = 0;
- int rightDistance = 0;
- countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance);
- if (leftDistance != rightDistance)
- return leftDistance > rightDistance ? -1 : 1;
// TODO: We should generalize above rules such that in a tie a declaration
// A::m is better than B::m when all other factors are equal and
@@ -1479,6 +1440,70 @@ namespace Slang
return 0;
}
+ int getScopeRank(DeclRef<Decl> const& left,
+ DeclRef<Decl> const& right, Slang::Scope* referenceSiteScope)
+ {
+ if (!referenceSiteScope)
+ return 0;
+
+ DeclRef<Decl> prefixDecl = referenceSiteScope->containerDecl;
+
+ // Hold the path from reference site to the root
+ // key: Decl node, value: distance from reference site
+ Dictionary<Decl*, uint32_t> refPath;
+ for (auto node = prefixDecl; node != nullptr; node = node.getParent())
+ {
+ Decl* key = node.getDecl();
+ uint32_t value = (uint32_t)refPath.getCount();
+ refPath.add(key, value);
+ }
+
+ // find the common prefix decl of reference site and left
+ int leftDistance = 0;
+ int rightDistance = 0;
+ auto distanceToCommonPrefix = [](DeclRef<Decl> const& candidate, Dictionary<Decl*, uint32_t> refPath) -> int
+ {
+ uint32_t distanceToReferenceSite = 0;
+ uint32_t distanceToCandidate = 0;
+
+ // Sanity check
+ if (candidate.getDecl() == nullptr)
+ return -1;
+
+ // search from candidate to root, once we found the first node in the reference path, that is the first
+ // common prefix, and we can stop searching.
+ for (auto node = candidate; node != nullptr; node = node.getParent())
+ {
+ Decl* key = node.getDecl();
+ if (refPath.tryGetValue(key, distanceToReferenceSite))
+ {
+ break;
+ }
+ distanceToCandidate++;
+ }
+
+ // If we don't find the common prefix, there must be something wrong, return the max value.
+ if (distanceToReferenceSite == 0)
+ return -1;
+
+ return distanceToReferenceSite + distanceToCandidate;
+ };
+
+ leftDistance = distanceToCommonPrefix(left, refPath);
+ rightDistance = distanceToCommonPrefix(right, refPath);
+
+ if (leftDistance == rightDistance)
+ return 0;
+
+ if (leftDistance == -1)
+ return 1;
+
+ if (rightDistance == -1)
+ return -1;
+
+ return leftDistance < rightDistance ? -1 : 1;
+ }
+
int SemanticsVisitor::CompareOverloadCandidates(
OverloadCandidate* left,
OverloadCandidate* right)
@@ -1558,6 +1583,42 @@ namespace Slang
if (externExportDiff)
return externExportDiff;
+ // We need to consider the distance of the declarations to the global scope to resolve this case:
+ // float f(float x);
+ // struct S
+ // {
+ // float f(float x);
+ // float g(float y) { return f(y); } // will call S::f() instead of ::f()
+ // }
+ // we will count the distance from the reference site to the declaration in the scope tree.
+
+ // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated.
+ // It will go through multiple passes of candidates compare.
+ // In the first pass, it will lookup all the generic candidates that matches the generic parameter only,
+ // e.g., the following generic functions are totally different, but they will be selected as candidates
+ // because the function name and the generic parameters are the same:
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b);
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c);
+ // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c, Z1 d);
+ //
+ // So in this case, we should not consider the scope rank and overload rank at all, because there is only
+ // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the
+ // candidates, so it could select the wrong candidate.
+ //
+ // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid
+ // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied.
+ if (left->flavor == OverloadCandidate::Flavor::Generic ||
+ left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric ||
+ right->flavor == OverloadCandidate::Flavor::Generic ||
+ right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric)
+ {
+ return 0;
+ }
+
+ auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope);
+ if (scopeRank)
+ return scopeRank;
+
// If we reach here, we will attempt to use overload rank to break the ties.
auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef);
if (overloadRankDiff)