summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-07-07 14:54:54 -0700
committerGitHub <noreply@github.com>2025-07-07 21:54:54 +0000
commit3865a6596afca1c193eb17bbb74008077096e7c3 (patch)
tree0d04cab0ad720b75027ddcee855daf6f6eba5d57 /source
parent7119d9cb487d866d1c25e55eafa03aca6e5e52e3 (diff)
Language server: sort completion candidate by relevance. (#7626)
* Language server: sort completion candidate by relevance. * Small adjustment.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp13
-rw-r--r--source/slang/slang-language-server-ast-lookup.h1
-rw-r--r--source/slang/slang-language-server-completion.cpp229
-rw-r--r--source/slang/slang-language-server-completion.h2
4 files changed, 239 insertions, 6 deletions
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 1ae9e4a0b..1f01baadc 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -7,6 +7,7 @@ namespace Slang
{
struct ASTLookupContext
{
+ Linkage* linkage;
DocumentVersion* doc;
SourceManager* sourceManager;
List<SyntaxNode*> nodePath;
@@ -167,6 +168,7 @@ public:
bool visitGenericAppExpr(GenericAppExpr* genericAppExpr)
{
+ PushNode pushNodeRAII(context, genericAppExpr);
if (dispatchIfNotNull(genericAppExpr->functionExpr))
return true;
for (auto arg : genericAppExpr->arguments)
@@ -231,7 +233,15 @@ public:
return true;
}
}
-
+ if (this->context->findType == ASTLookupType::CompletionRequest &&
+ expr->name == context->linkage->getSessionImpl()->getCompletionRequestTokenName())
+ {
+ ASTLookupResult result;
+ result.path = context->nodePath;
+ result.path.add(expr);
+ context->results.add(result);
+ return true;
+ }
return dispatchIfNotNull(expr->originalExpr);
}
@@ -861,6 +871,7 @@ List<ASTLookupResult> findASTNodesAt(
context.findType = findType;
context.sourceFileName = fileName;
context.doc = doc;
+ context.linkage = getModule(moduleDecl)->getLinkage();
_findAstNodeImpl(context, moduleDecl);
return context.results;
}
diff --git a/source/slang/slang-language-server-ast-lookup.h b/source/slang/slang-language-server-ast-lookup.h
index 6970667af..6d9d3be6c 100644
--- a/source/slang/slang-language-server-ast-lookup.h
+++ b/source/slang/slang-language-server-ast-lookup.h
@@ -13,6 +13,7 @@ enum class ASTLookupType
{
Decl,
Invoke,
+ CompletionRequest,
};
struct Loc
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index b5864d972..e1758ca5d 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -580,6 +580,217 @@ String CompletionContext::formatDeclForCompletion(
return printer.getString();
}
+// Returns true if `exprNode` is the same as `targetExpr`, or if the original expr node
+// of `exprNode` before any checking/transformation is the same as `targetExpr`.
+bool matchExpr(Expr* exprNode, SyntaxNode* targetExpr)
+{
+ if (!exprNode)
+ return false;
+ if (exprNode == targetExpr)
+ return true;
+ if (auto invokeExpr = as<AppExprBase>(exprNode))
+ return matchExpr(invokeExpr->originalFunctionExpr, targetExpr);
+ if (auto overloadedExpr = as<OverloadedExpr>(exprNode))
+ return matchExpr(overloadedExpr->originalExpr, targetExpr);
+ if (auto partiallyAppliedExpr = as<PartiallyAppliedGenericExpr>(exprNode))
+ return matchExpr(partiallyAppliedExpr->originalExpr, targetExpr);
+ if (auto extractExistentialExpr = as<ExtractExistentialValueExpr>(exprNode))
+ return matchExpr(extractExistentialExpr->originalExpr, targetExpr);
+ if (auto declRefExpr = as<DeclRefExpr>(exprNode))
+ return matchExpr(declRefExpr->originalExpr, targetExpr);
+ return false;
+}
+
+// Infer the accepted types at the completion position based on the AST nodes.
+//
+List<Type*> CompletionContext::getExpectedTypesAtCompletion(const List<ASTLookupResult>& astNodes)
+{
+ List<Type*> expectedType;
+ if (astNodes.getCount() == 0)
+ return expectedType;
+ auto& path = astNodes.getFirst().path;
+ if (path.getCount() < 2)
+ return expectedType;
+ auto completionExprNode = path.getLast();
+ auto parentNode = path[path.getCount() - 2];
+ auto collectArgumentType = [&](AppExprBase* appExpr, Index argIndex)
+ {
+ if (!appExpr)
+ return;
+ auto functionExpr = appExpr->functionExpr;
+ if (!functionExpr)
+ return;
+ if (as<InvokeExpr>(appExpr))
+ {
+ // If we are in an invoke expr, we will use the parameter type of the
+ // callee as the expected type.
+ auto processDeclRefCallee = [&](DeclRef<Decl> calleeDeclRef)
+ {
+ auto decl = calleeDeclRef.getDecl();
+ auto callableDecl = as<CallableDecl>(decl);
+ if (!callableDecl)
+ return;
+ Index paramIndex = 0;
+ for (auto paramDeclRef :
+ getMembersOfType<ParamDecl>(version->linkage->getASTBuilder(), callableDecl))
+ {
+ if (paramIndex == argIndex)
+ {
+ expectedType.add(getType(version->linkage->getASTBuilder(), paramDeclRef));
+ return;
+ }
+ paramIndex++;
+ }
+ };
+ if (auto declRefExpr = as<DeclRefExpr>(functionExpr))
+ processDeclRefCallee(declRefExpr->declRef);
+ else if (auto overloadedExpr = as<OverloadedExpr>(functionExpr))
+ {
+ for (auto& lookupResult : overloadedExpr->lookupResult2)
+ processDeclRefCallee(lookupResult.declRef);
+ }
+ }
+ else if (as<GenericAppExpr>(appExpr))
+ {
+ auto declRefExpr = as<DeclRefExpr>(functionExpr);
+ if (!declRefExpr)
+ return;
+ auto genericDecl = as<GenericDecl>(declRefExpr->declRef.getDecl());
+ if (!genericDecl)
+ return;
+
+ for (auto member : genericDecl->getMembers())
+ {
+ if (auto valParamDecl = as<GenericValueParamDecl>(member))
+ {
+ if (valParamDecl->parameterIndex == argIndex)
+ {
+ expectedType.add(valParamDecl->type.type);
+ return;
+ }
+ }
+ }
+ }
+ };
+
+ if (auto implicitCastExpr = as<ImplicitCastExpr>(parentNode))
+ {
+ // If the completion request is in (SomeType)(!completionRequest), then we should prefer any
+ // candidates that has `SomeType`.
+ if (implicitCastExpr->arguments.getCount() == 1 &&
+ matchExpr(implicitCastExpr->arguments[0], completionExprNode))
+ {
+ if (as<DeclRefType>(implicitCastExpr->type.type))
+ expectedType.add(implicitCastExpr->type.type);
+ }
+ return expectedType;
+ }
+ if (auto invokeExpr = as<AppExprBase>(parentNode))
+ {
+ // If parent node is an invoke expr, check if we are in an argument position.
+ for (Index i = 0; i < invokeExpr->arguments.getCount(); i++)
+ {
+ if (matchExpr(invokeExpr->arguments[i], completionExprNode))
+ {
+ // If we are in an argument position, we will use the expected type of the
+ // argument.
+ collectArgumentType(invokeExpr, i);
+ break;
+ }
+ }
+ return expectedType;
+ }
+ if (auto varDecl = as<VarDeclBase>(parentNode))
+ {
+ if (!varDecl)
+ return expectedType;
+ if (!matchExpr(varDecl->initExpr, completionExprNode))
+ return expectedType;
+ if (as<DeclRefType>(varDecl->type.type))
+ {
+ expectedType.add(varDecl->type.type);
+ }
+ return expectedType;
+ }
+ return expectedType;
+}
+
+Index CompletionContext::determineCompletionItemSortOrder(
+ Decl* item,
+ const List<Type*>& expectedTypes)
+{
+ if (expectedTypes.getCount() == 0)
+ return -1;
+
+ // Test if `itemType` matches `expectedType`, and return the relevance of the match.
+ // -1 means no match, a positive number means a match.
+ // The smaller the number, the more relevant the match is, and the item will be listed
+ // earlier in the completion list.
+ auto matchType = [&](Type* itemType, DeclRefType* expectedType) -> Index
+ {
+ if (itemType == expectedType)
+ return 1; // Exact match
+
+ auto declRef = isDeclRefTypeOf<Decl>(itemType);
+ if (!declRef)
+ return -1; // No match
+
+ if (declRef.getDecl() == expectedType->getDeclRef().getDecl())
+ return 2; // Match by decl
+
+ // We may also want to extend the matching logic to include subtyping or other
+ // coercion relationships. But for now, we will just check for simple matches
+ // to avoid performance problems.
+ //
+ return -1;
+ };
+
+ Index result = -1;
+
+ // If we have any expected types, we will sort the completion candiate items by their relevance
+ // to the expected types.
+ // If the item has expected type, we will assign a sort order to make it appear at the top
+ // of the completion list.
+ for (auto et : expectedTypes)
+ {
+ Index currentSortOrder = -1;
+ auto etDeclRefType = as<DeclRefType>(et);
+ if (!etDeclRefType)
+ continue;
+ if (item == etDeclRefType->getDeclRef().getDecl())
+ {
+ if (as<EnumDecl>(item))
+ currentSortOrder = 0;
+ else if (!as<InterfaceDecl>(item))
+ currentSortOrder = 1;
+ }
+ else if (auto varItem = as<VarDeclBase>(item))
+ {
+ currentSortOrder = matchType(varItem->type.type, etDeclRefType);
+ }
+ else if (auto callableItem = as<CallableDecl>(item))
+ {
+ // If the item is a callable decl, we will check if the return type matches the expected
+ // type.
+ currentSortOrder = matchType(callableItem->returnType.type, etDeclRefType);
+ }
+ if (result == -1 || (currentSortOrder != -1 && currentSortOrder < result))
+ {
+ // If we have a better match, we will update the result.
+ result = currentSortOrder;
+ }
+ }
+ // Always list decls within the same module first.
+ // Note if result == 0, it means the item is representing the expected enum type itself,
+ // so we always want to list it first by not increasing `result`.
+ if (result > 0 && getModule(item) != parsedModule)
+ result++;
+ // List core module decls last.
+ if (result > 0 && isFromCoreModule(item))
+ result++;
+ return result;
+}
+
CompletionResult CompletionContext::collectMembersAndSymbols()
{
List<LanguageServerProtocol::CompletionItem> result;
@@ -626,7 +837,15 @@ CompletionResult CompletionContext::collectMembersAndSymbols()
addKeywords = false;
break;
}
-
+ auto lookupResults = findASTNodesAt(
+ doc,
+ version->linkage->getSourceManager(),
+ parsedModule->getModuleDecl(),
+ ASTLookupType::CompletionRequest,
+ canonicalPath,
+ line,
+ col);
+ auto expectedTypes = getExpectedTypesAtCompletion(lookupResults);
HashSet<String> deduplicateSet;
for (Index i = 0;
i < linkage->contentAssistInfo.completionSuggestions.candidateItems.getCount();
@@ -734,13 +953,13 @@ CompletionResult CompletionContext::collectMembersAndSymbols()
item.kind = LanguageServerProtocol::kCompletionItemKindClass;
}
item.data = String(i);
- if (linkage->contentAssistInfo.completionSuggestions.formatMode !=
- CompletionSuggestions::FormatMode::Name)
+
+ Index sortOrder = determineCompletionItemSortOrder(member, expectedTypes);
+ if (sortOrder != -1)
{
item.sortText =
- (StringBuilder() << i << ":" << getText(member->getName())).produceString();
+ (StringBuilder() << sortOrder << ":" << getText(member->getName())).produceString();
}
-
result.add(item);
if (nameStart > 1)
{
diff --git a/source/slang/slang-language-server-completion.h b/source/slang/slang-language-server-completion.h
index 513538ac8..57b4fbda7 100644
--- a/source/slang/slang-language-server-completion.h
+++ b/source/slang/slang-language-server-completion.h
@@ -56,6 +56,8 @@ struct CompletionContext
Index fileNameStartPos,
bool isImportString);
+ List<Type*> getExpectedTypesAtCompletion(const List<ASTLookupResult>& astNodes);
+ Index determineCompletionItemSortOrder(Decl* item, const List<Type*>& expectedTypes);
CompletionResult collectMembersAndSymbols();
String formatDeclForCompletion(