From b4fc380af5e390ca11892f9e657e653f6869c21b Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 3 Jul 2025 15:20:23 -0700 Subject: Language Server Enhancements (#7604) * Language Server: auto-select the best candidate in signature help. * Fix constructor call highlighting + goto definition. * Add test. * format code * Improve ctor signature help. * Add tests. * Fix decl path printing for extension children. * Allow goto definition to show core module source. * c++ compile fix. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/slang-ast-iterator.h | 1 + source/slang/slang-ast-print.cpp | 6 +- source/slang/slang-check-impl.h | 2 +- source/slang/slang-check-overload.cpp | 35 +++-- source/slang/slang-compiler.h | 2 + source/slang/slang-language-server-ast-lookup.cpp | 10 -- source/slang/slang-language-server.cpp | 177 ++++++++++++++++++++-- source/slang/slang-language-server.h | 3 + source/slang/slang-workspace-version.cpp | 5 +- source/slang/slang.cpp | 27 ++-- 10 files changed, 224 insertions(+), 44 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 094a9c1a2..3cce8df59 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -133,6 +133,7 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->functionExpr); + dispatchIfNotNull(expr->originalFunctionExpr); for (auto arg : expr->arguments) dispatchIfNotNull(arg); } diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 4b5a69f15..ee747a4c2 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -1129,6 +1129,7 @@ void ASTPrinter::addVal(Val* val) /* static */ void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out) { + decl = maybeGetInner(decl); if (as(decl)) { out << "init"; @@ -1231,8 +1232,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) } else if (auto extensionDeclRef = parentDeclRef.as()) { - ExtensionDecl* extensionDecl = as(parentDeclRef.getDecl()); - Type* type = extensionDecl->targetType.type; + Type* type = getTargetType(m_astBuilder, extensionDeclRef); if (m_optionFlags & OptionFlag::NoSpecializedExtensionTypeName) { if (auto unspecializedDeclRef = isDeclRefTypeOf(type)) @@ -1522,6 +1522,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl) continue; if (as(modifier)) continue; + if (as(modifier)) + continue; } // Don't print out attributes. if (as(modifier)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 7ddec20fb..30e317401 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1252,8 +1252,8 @@ public: TypeExp TranslateTypeNode(TypeExp const& typeExp); Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); Type* getConstantBufferType(Type* elementType, Type* layoutType); - DeclRefType* getExprDeclRefType(Expr* expr); + LookupResult lookupConstructorsInType(Type* type, Scope* sourceScope); /// Is `decl` usable as a static member? bool isDeclUsableAsStaticMember(Decl* decl); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 6c0a7f184..41aba2674 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2249,6 +2249,24 @@ DeclRef SemanticsVisitor::inferGenericArguments( return trySolveConstraintSystem(&constraints, genericDeclRef, knownGenericArgs, outBaseCost); } +LookupResult SemanticsVisitor::lookupConstructorsInType(Type* type, Scope* sourceScope) +{ + // Look up all the initializers on `type` by looking up + // its members named `$init`. All `__init` declarations are stored + // with the name `$init` internally to avoid potential conflicts + // if a user decided to name a field/method `__init`. + LookupOptions options = + LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); + return lookUpMember( + m_astBuilder, + this, + getName("$init"), + type, + sourceScope, + LookupMask::Default, + options); +} + void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context) { // The code being checked is trying to apply `type` like a function. @@ -2272,16 +2290,7 @@ void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveCont // from a value of the same type. There is no need in Slang for // "copy constructors" but the core module currently has to define // some just to make code that does, e.g., `float(1.0f)` work.) - LookupOptions options = - LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); - LookupResult initializers = lookUpMember( - m_astBuilder, - this, - getName("$init"), - type, - context.sourceScope, - LookupMask::Default, - options); + LookupResult initializers = lookupConstructorsInType(type, context.sourceScope); AddOverloadCandidates(initializers, context); } @@ -2702,6 +2711,12 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) expr->arguments[0], &tempSink, &conversionCost); + if (auto resultInvokeExpr = as(resultExpr)) + { + resultInvokeExpr->originalFunctionExpr = expr->functionExpr; + resultInvokeExpr->argumentDelimeterLocs = expr->argumentDelimeterLocs; + resultInvokeExpr->loc = expr->loc; + } if (coerceResult) return resultExpr; typeOverloadChecked = true; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 60f6cc92f..7cdd1614c 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -3764,6 +3764,8 @@ public: ComPtr getAutodiffLibraryCode(); ComPtr getGLSLLibraryCode(); + void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName); + RefPtr m_sharedASTBuilder; SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index f394e7d7a..1ae9e4a0b 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -235,16 +235,6 @@ public: return dispatchIfNotNull(expr->originalExpr); } - bool visitTypeCastExpr(TypeCastExpr* expr) - { - if (dispatchIfNotNull(expr->functionExpr)) - return true; - for (auto arg : expr->arguments) - if (dispatchIfNotNull(arg)) - return true; - return false; - } - bool visitDerefExpr(DerefExpr* expr) { return dispatchIfNotNull(expr->base); } bool visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) { diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index d52c5d855..ba2722fae 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -565,6 +565,37 @@ HumaneSourceLoc getModuleLoc(SourceManager* manager, ContainerDecl* moduleDecl) return location; } +// When user code has `Foo(123)` where `Foo` is a `struct`, goto-definition on +// `Foo` should redirect to the constructor of `Foo` instead of the type declaration of `Foo`. +// This function will check if the `declRefExpr` is a reference to a type declaration, +// but the declRefExpr is referenced from an `InvokeExpr::originalFunctionExpr` that is now +// resolved to a constructor. If so we will return the declRef of the constructor. +// +DeclRef maybeRedirectToConstructor(DeclRefExpr* declRefExpr, const List& path) +{ + if (path.getCount() < 2) + return declRefExpr->declRef; + if (!as(declRefExpr->declRef)) + return declRefExpr->declRef; + auto invokeExpr = as(path[path.getCount() - 2]); + if (!invokeExpr) + return declRefExpr->declRef; + if (!invokeExpr->originalFunctionExpr) + return declRefExpr->declRef; + auto originalFuncExpr = invokeExpr->originalFunctionExpr; + if (originalFuncExpr != declRefExpr) + return declRefExpr->declRef; + // If the invoke expression is the same as the decl ref expression, + // it means we are looking at a constructor call. + auto resolvedFuncExpr = as(invokeExpr->functionExpr); + if (!resolvedFuncExpr) + return declRefExpr->declRef; + auto ctorDecl = as(resolvedFuncExpr->declRef); + if (ctorDecl) + return ctorDecl; + return declRefExpr->declRef; +} + SlangResult LanguageServer::hover( const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId) @@ -828,7 +859,8 @@ LanguageServerResult LanguageServerCore::hover( }; if (auto declRefExpr = as(leafNode)) { - fillDeclRefHoverInfo(declRefExpr->declRef, declRefExpr->name); + auto resolvedDeclRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); + fillDeclRefHoverInfo(resolvedDeclRef, declRefExpr->name); } else if (auto overloadedExpr = as(leafNode)) { @@ -1004,11 +1036,12 @@ LanguageServerResult> LanguageServerCore: { if (declRefExpr->declRef.getDecl()) { + auto declRef = declRefExpr->declRef; + declRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); auto location = version->linkage->getSourceManager()->getHumaneLoc( - declRefExpr->declRef.getNameLoc().isValid() ? declRefExpr->declRef.getNameLoc() - : declRefExpr->declRef.getLoc(), + declRef.getNameLoc().isValid() ? declRef.getNameLoc() : declRef.getLoc(), SourceLocType::Actual); - auto name = declRefExpr->declRef.getName(); + auto name = declRef.getName(); locations.add(LocationResult{ location, name ? (int)UTF8Util::calcUTF16CharCount(name->text.getUnownedSlice()) : 0}); @@ -1076,6 +1109,14 @@ LanguageServerResult> LanguageServerCore: { result.uri = URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri; + } + else if (loc.loc.pathInfo.getName() == "core" || loc.loc.pathInfo.getName() == "glsl") + { + result.uri = StringBuilder() << "slang-synth://" << loc.loc.pathInfo.getName() + << "/" << loc.loc.pathInfo.getName() << ".builtin"; + } + if (result.uri.getLength() != 0) + { doc->oneBasedUTF8LocToZeroBasedUTF16Loc( loc.loc.line, loc.loc.column, @@ -1504,6 +1545,75 @@ SlangResult LanguageServer::signatureHelp( return SLANG_OK; } +// Heuristical cost for determining the best candidate to use as the active signature. +// We will always use the candidate that has the most matched parameters to the current argument +// list. If there are multiple candidates with the same number of matched parameters, we will +// use the one with the least number of unmatched parameters. If there are still multiple +// candidates with the same number of unmatched parameters, we will use the one with the least +// maximum argument conversion cost. +// +struct CallCandidateMatchCost +{ + Index matchedArgCount = 0; + Index excessArgCount = 0; + Index unmatchedParamCount = 0; + ConversionCost maxArgConversionCost = 0; + + bool isBetterThan(const CallCandidateMatchCost& other) const + { + if (excessArgCount < other.excessArgCount) + return true; + else if (excessArgCount > other.excessArgCount) + return false; + if (matchedArgCount > other.matchedArgCount) + return true; + else if (matchedArgCount < other.matchedArgCount) + return false; + + if (unmatchedParamCount < other.unmatchedParamCount) + return true; + else if (unmatchedParamCount > other.unmatchedParamCount) + return false; + return maxArgConversionCost < other.maxArgConversionCost; + } +}; + +// Given a callable decl and an AppExprBase containing the arguments used to call it, +// return the match cost for the candidate. +static CallCandidateMatchCost getCallCandidateMatchCost( + DeclRef callableDeclRef, + AppExprBase* appExpr, + SemanticsVisitor& semanticsVisitor, + WorkspaceVersion* version) +{ + CallCandidateMatchCost result; + auto astBuilder = version->linkage->getASTBuilder(); + auto paramList = getMembersOfType(astBuilder, callableDeclRef).toArray(); + + for (Index argId = 0; argId < appExpr->arguments.getCount(); argId++) + { + auto arg = appExpr->arguments[argId]; + if (!arg) + continue; + if (!arg->type.type) + continue; + if (argId < paramList.getCount()) + { + auto paramType = getType(version->linkage->getASTBuilder(), paramList[argId]); + ConversionCost argCost = 0; + if (paramType && semanticsVisitor.canCoerce(paramType, arg->type.type, arg, &argCost)) + { + result.matchedArgCount++; + result.maxArgConversionCost = Math::Max(result.maxArgConversionCost, argCost); + } + } + } + result.excessArgCount = + Math::Max((Index)0, (appExpr->argumentDelimeterLocs.getCount() - 1) - paramList.getCount()); + result.unmatchedParamCount = paramList.getCount() - result.matchedArgCount; + return result; +} + LanguageServerResult LanguageServerCore::signatureHelp( const LanguageServerProtocol::SignatureHelpParams& args) { @@ -1594,11 +1704,41 @@ LanguageServerResult LanguageServerCore:: } SignatureHelp response; + response.activeSignature = 0; + + CallCandidateMatchCost bestCandidateMatchCost; + + // We will use an ad-hoc semantics visitor to check for argument-to-parameter conversions + // and to determine the best candidate signature. + // In the ideal design, this info should be gathered during the normal type checking + // process, but that require a lot of refactoring in the current code base, and may + // risk slowing down type checking for non-language-server use cases since we won't be + // able to do as many early returns. + // So instead we will do a separate ad-hoc checking here to do a best-effort guess + // on the best candidate. + // + DiagnosticSink sink; + SharedSemanticsContext semanticsContext(version->linkage, nullptr, &sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); + auto addDeclRef = [&](DeclRef declRef) { if (!declRef.getDecl()) return; + // If we have a better match than the current best, we will update response.activeSignature + // to this signature. + if (auto callableDeclRef = declRef.as()) + { + auto matchCost = + getCallCandidateMatchCost(callableDeclRef, appExpr, semanticsVisitor, version); + if (matchCost.isBetterThan(bestCandidateMatchCost)) + { + bestCandidateMatchCost = matchCost; + response.activeSignature = (uint32_t)response.signatures.getCount(); + } + } + SignatureInformation sigInfo; List> paramRanges; @@ -1675,13 +1815,14 @@ LanguageServerResult LanguageServerCore:: if (auto declRefExpr = as(funcExpr)) { - if (auto aggDeclRef = as(declRefExpr->declRef)) + if (auto typeType = as(declRefExpr->type.type)) { // Look for initializers - for (auto member : - getMembersOfType(version->linkage->getASTBuilder(), aggDeclRef)) + auto ctors = + semanticsVisitor.lookupConstructorsInType(typeType->getType(), declRefExpr->scope); + for (auto ctor : ctors) { - addDeclRef(member); + addDeclRef(ctor.declRef); } } else @@ -1711,7 +1852,6 @@ LanguageServerResult LanguageServerCore:: { addFuncType(funcType); } - response.activeSignature = 0; response.activeParameter = 0; for (int i = 1; i < appExpr->argumentDelimeterLocs.getCount(); i++) { @@ -2828,4 +2968,23 @@ SLANG_API SlangResult runLanguageServer(Slang::LanguageServerStartupOptions opti return SLANG_OK; } +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob) +{ + ComPtr globalSession; + slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()); + Slang::Session* session = static_cast(globalSession.get()); + StringBuilder sb; + if (moduleName.startsWith("core")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::Core); + } + else if (moduleName.startsWith("glsl")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::GLSL); + } + *blob = StringBlob::moveCreate(sb.produceString()).detach(); + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index 43c3521eb..31b7114c0 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -275,4 +275,7 @@ inline bool _isIdentifierChar(char ch) } SLANG_API SlangResult runLanguageServer(LanguageServerStartupOptions options); +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob); + } // namespace Slang diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp index b1a3dec34..63bf7ed53 100644 --- a/source/slang/slang-workspace-version.cpp +++ b/source/slang/slang-workspace-version.cpp @@ -487,7 +487,10 @@ void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( Index rsLine = inLine - 1; auto bounds = getUTF16Boundaries(inLine); outLine = rsLine; - outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + if (bounds.getCount() != 0) + outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + else + outCol = inCol - 1; } void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 297a3464f..9bfc2bce9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -435,6 +435,21 @@ SlangResult Session::compileCoreModule(slang::CompileCoreModuleFlags compileFlag return compileBuiltinModule(slang::BuiltinModuleName::Core, compileFlags); } +void Session::getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName) +{ + switch (moduleName) + { + case slang::BuiltinModuleName::Core: + sb << (const char*)getCoreLibraryCode()->getBufferPointer() + << (const char*)getHLSLLibraryCode()->getBufferPointer() + << (const char*)getAutodiffLibraryCode()->getBufferPointer(); + break; + case slang::BuiltinModuleName::GLSL: + sb << (const char*)getGLSLLibraryCode()->getBufferPointer(); + break; + } +} + SlangResult Session::compileBuiltinModule( slang::BuiltinModuleName moduleName, slang::CompileCoreModuleFlags compileFlags) @@ -460,17 +475,7 @@ SlangResult Session::compileBuiltinModule( } StringBuilder moduleSrcBuilder; - switch (moduleName) - { - case slang::BuiltinModuleName::Core: - moduleSrcBuilder << (const char*)getCoreLibraryCode()->getBufferPointer() - << (const char*)getHLSLLibraryCode()->getBufferPointer() - << (const char*)getAutodiffLibraryCode()->getBufferPointer(); - break; - case slang::BuiltinModuleName::GLSL: - moduleSrcBuilder << (const char*)getGLSLLibraryCode()->getBufferPointer(); - break; - } + getBuiltinModuleSource(moduleSrcBuilder, moduleName); // TODO(JS): Could make this return a SlangResult as opposed to exception auto moduleSrcBlob = StringBlob::moveCreate(moduleSrcBuilder.produceString()); -- cgit v1.2.3