diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 139 |
1 files changed, 100 insertions, 39 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 70eabb4f7..b2173cd7b 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1200,28 +1200,6 @@ namespace Slang return parent; } - void countDistanceToGloablScope(DeclRef<Slang::Decl> const& leftDecl, - DeclRef<Slang::Decl> const& rightDecl, - int& leftDistance, int& rightDistance) - { - leftDistance = 0; - rightDistance = 0; - - DeclRef<Decl> decl = leftDecl; - while(decl) - { - leftDistance++; - decl = decl.getParent(); - } - - decl = rightDecl; - while(decl) - { - rightDistance++; - decl = decl.getParent(); - } - } - // Returns -1 if left is preferred, 1 if right is preferred, and 0 if they are equal. // int SemanticsVisitor::CompareLookupResultItems( @@ -1347,23 +1325,6 @@ namespace Slang } } - // We need to consider the distance of the declarations to the global scope to resolve this case: - // float f(float x); - // struct S - // { - // float f(float x); - // float g(float y) { return f(y); } // will call S::f() instead of ::f() - // } - // We don't need to know the call site of 'f(y)', but only need to count the two candidates' distance to the global scope, - // because this function will only choose the valid candidates. So if there is situation like this: - // void main() { S s; s.f(1.0);} or - // struct T { float g(y) { f(y); } }, there won't be ambiguity. - // So we just need to count which declaration is farther from the global scope and favor the farther one. - int leftDistance = 0; - int rightDistance = 0; - countDistanceToGloablScope(left.declRef, right.declRef, leftDistance, rightDistance); - if (leftDistance != rightDistance) - return leftDistance > rightDistance ? -1 : 1; // TODO: We should generalize above rules such that in a tie a declaration // A::m is better than B::m when all other factors are equal and @@ -1479,6 +1440,70 @@ namespace Slang return 0; } + int getScopeRank(DeclRef<Decl> const& left, + DeclRef<Decl> const& right, Slang::Scope* referenceSiteScope) + { + if (!referenceSiteScope) + return 0; + + DeclRef<Decl> prefixDecl = referenceSiteScope->containerDecl; + + // Hold the path from reference site to the root + // key: Decl node, value: distance from reference site + Dictionary<Decl*, uint32_t> refPath; + for (auto node = prefixDecl; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + uint32_t value = (uint32_t)refPath.getCount(); + refPath.add(key, value); + } + + // find the common prefix decl of reference site and left + int leftDistance = 0; + int rightDistance = 0; + auto distanceToCommonPrefix = [](DeclRef<Decl> const& candidate, Dictionary<Decl*, uint32_t> refPath) -> int + { + uint32_t distanceToReferenceSite = 0; + uint32_t distanceToCandidate = 0; + + // Sanity check + if (candidate.getDecl() == nullptr) + return -1; + + // search from candidate to root, once we found the first node in the reference path, that is the first + // common prefix, and we can stop searching. + for (auto node = candidate; node != nullptr; node = node.getParent()) + { + Decl* key = node.getDecl(); + if (refPath.tryGetValue(key, distanceToReferenceSite)) + { + break; + } + distanceToCandidate++; + } + + // If we don't find the common prefix, there must be something wrong, return the max value. + if (distanceToReferenceSite == 0) + return -1; + + return distanceToReferenceSite + distanceToCandidate; + }; + + leftDistance = distanceToCommonPrefix(left, refPath); + rightDistance = distanceToCommonPrefix(right, refPath); + + if (leftDistance == rightDistance) + return 0; + + if (leftDistance == -1) + return 1; + + if (rightDistance == -1) + return -1; + + return leftDistance < rightDistance ? -1 : 1; + } + int SemanticsVisitor::CompareOverloadCandidates( OverloadCandidate* left, OverloadCandidate* right) @@ -1558,6 +1583,42 @@ namespace Slang if (externExportDiff) return externExportDiff; + // We need to consider the distance of the declarations to the global scope to resolve this case: + // float f(float x); + // struct S + // { + // float f(float x); + // float g(float y) { return f(y); } // will call S::f() instead of ::f() + // } + // we will count the distance from the reference site to the declaration in the scope tree. + + // NOTE: We CAN'T do this for the generic function, because generic lookup is little bit complicated. + // It will go through multiple passes of candidates compare. + // In the first pass, it will lookup all the generic candidates that matches the generic parameter only, + // e.g., the following generic functions are totally different, but they will be selected as candidates + // because the function name and the generic parameters are the same: + // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b); + // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c); + // void func<let Z0 : uint, let Z1 : uint>(Z0 a, Z1 b, Z0 c, Z1 d); + // + // So in this case, we should not consider the scope rank and overload rank at all, because there is only + // one of above candidates is valid, and the rank calculation doesn't consider the correctness of the + // candidates, so it could select the wrong candidate. + // + // In the next pass, the lookup system will match the input parameters in those candidates to find out the valid + // match, the "flavor" field will become "Func" or "Expr". So the rank calculation can be applied. + if (left->flavor == OverloadCandidate::Flavor::Generic || + left->flavor == OverloadCandidate::Flavor::UnspecializedGeneric || + right->flavor == OverloadCandidate::Flavor::Generic || + right->flavor == OverloadCandidate::Flavor::UnspecializedGeneric) + { + return 0; + } + + auto scopeRank = getScopeRank(left->item.declRef, right->item.declRef, this->m_outerScope); + if (scopeRank) + return scopeRank; + // If we reach here, we will attempt to use overload rank to break the ties. auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); if (overloadRankDiff) |
