diff options
| author | Yong He <yonghe@outlook.com> | 2025-07-03 15:20:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-03 22:20:23 +0000 |
| commit | b4fc380af5e390ca11892f9e657e653f6869c21b (patch) | |
| tree | 9072841ed14a190cce0790ced27b283f85d1fc4f /source | |
| parent | 551d0c365571a2e36505851f6a713464662c5fea (diff) | |
Language Server Enhancements (#7604)
* Language Server: auto-select the best candidate in signature help.
* Fix constructor call highlighting + goto definition.
* Add test.
* format code
* Improve ctor signature help.
* Add tests.
* Fix decl path printing for extension children.
* Allow goto definition to show core module source.
* c++ compile fix.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-print.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-language-server.cpp | 177 | ||||
| -rw-r--r-- | source/slang/slang-language-server.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-workspace-version.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 27 |
10 files changed, 224 insertions, 44 deletions
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 094a9c1a2..3cce8df59 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -133,6 +133,7 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->functionExpr); + dispatchIfNotNull(expr->originalFunctionExpr); for (auto arg : expr->arguments) dispatchIfNotNull(arg); } diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 4b5a69f15..ee747a4c2 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -1129,6 +1129,7 @@ void ASTPrinter::addVal(Val* val) /* static */ void ASTPrinter::appendDeclName(Decl* decl, StringBuilder& out) { + decl = maybeGetInner(decl); if (as<ConstructorDecl>(decl)) { out << "init"; @@ -1231,8 +1232,7 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) } else if (auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>()) { - ExtensionDecl* extensionDecl = as<ExtensionDecl>(parentDeclRef.getDecl()); - Type* type = extensionDecl->targetType.type; + Type* type = getTargetType(m_astBuilder, extensionDeclRef); if (m_optionFlags & OptionFlag::NoSpecializedExtensionTypeName) { if (auto unspecializedDeclRef = isDeclRefTypeOf<Decl>(type)) @@ -1522,6 +1522,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl) continue; if (as<HLSLLayoutSemantic>(modifier)) continue; + if (as<ImplicitConversionModifier>(modifier)) + continue; } // Don't print out attributes. if (as<AttributeBase>(modifier)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 7ddec20fb..30e317401 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1252,8 +1252,8 @@ public: TypeExp TranslateTypeNode(TypeExp const& typeExp); Type* getRemovedModifierType(ModifiedType* type, ModifierVal* modifier); Type* getConstantBufferType(Type* elementType, Type* layoutType); - DeclRefType* getExprDeclRefType(Expr* expr); + LookupResult lookupConstructorsInType(Type* type, Scope* sourceScope); /// Is `decl` usable as a static member? bool isDeclUsableAsStaticMember(Decl* decl); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 6c0a7f184..41aba2674 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2249,6 +2249,24 @@ DeclRef<Decl> SemanticsVisitor::inferGenericArguments( return trySolveConstraintSystem(&constraints, genericDeclRef, knownGenericArgs, outBaseCost); } +LookupResult SemanticsVisitor::lookupConstructorsInType(Type* type, Scope* sourceScope) +{ + // Look up all the initializers on `type` by looking up + // its members named `$init`. All `__init` declarations are stored + // with the name `$init` internally to avoid potential conflicts + // if a user decided to name a field/method `__init`. + LookupOptions options = + LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); + return lookUpMember( + m_astBuilder, + this, + getName("$init"), + type, + sourceScope, + LookupMask::Default, + options); +} + void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveContext& context) { // The code being checked is trying to apply `type` like a function. @@ -2272,16 +2290,7 @@ void SemanticsVisitor::AddTypeOverloadCandidates(Type* type, OverloadResolveCont // from a value of the same type. There is no need in Slang for // "copy constructors" but the core module currently has to define // some just to make code that does, e.g., `float(1.0f)` work.) - LookupOptions options = - LookupOptions(uint8_t(LookupOptions::IgnoreInheritance) | uint8_t(LookupOptions::NoDeref)); - LookupResult initializers = lookUpMember( - m_astBuilder, - this, - getName("$init"), - type, - context.sourceScope, - LookupMask::Default, - options); + LookupResult initializers = lookupConstructorsInType(type, context.sourceScope); AddOverloadCandidates(initializers, context); } @@ -2702,6 +2711,12 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) expr->arguments[0], &tempSink, &conversionCost); + if (auto resultInvokeExpr = as<InvokeExpr>(resultExpr)) + { + resultInvokeExpr->originalFunctionExpr = expr->functionExpr; + resultInvokeExpr->argumentDelimeterLocs = expr->argumentDelimeterLocs; + resultInvokeExpr->loc = expr->loc; + } if (coerceResult) return resultExpr; typeOverloadChecked = true; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 60f6cc92f..7cdd1614c 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -3764,6 +3764,8 @@ public: ComPtr<ISlangBlob> getAutodiffLibraryCode(); ComPtr<ISlangBlob> getGLSLLibraryCode(); + void getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName); + RefPtr<SharedASTBuilder> m_sharedASTBuilder; SPIRVCoreGrammarInfo& getSPIRVCoreGrammarInfo() diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index f394e7d7a..1ae9e4a0b 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -235,16 +235,6 @@ public: return dispatchIfNotNull(expr->originalExpr); } - bool visitTypeCastExpr(TypeCastExpr* expr) - { - if (dispatchIfNotNull(expr->functionExpr)) - return true; - for (auto arg : expr->arguments) - if (dispatchIfNotNull(arg)) - return true; - return false; - } - bool visitDerefExpr(DerefExpr* expr) { return dispatchIfNotNull(expr->base); } bool visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) { diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index d52c5d855..ba2722fae 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -565,6 +565,37 @@ HumaneSourceLoc getModuleLoc(SourceManager* manager, ContainerDecl* moduleDecl) return location; } +// When user code has `Foo(123)` where `Foo` is a `struct`, goto-definition on +// `Foo` should redirect to the constructor of `Foo` instead of the type declaration of `Foo`. +// This function will check if the `declRefExpr` is a reference to a type declaration, +// but the declRefExpr is referenced from an `InvokeExpr::originalFunctionExpr` that is now +// resolved to a constructor. If so we will return the declRef of the constructor. +// +DeclRef<Decl> maybeRedirectToConstructor(DeclRefExpr* declRefExpr, const List<SyntaxNode*>& path) +{ + if (path.getCount() < 2) + return declRefExpr->declRef; + if (!as<AggTypeDecl>(declRefExpr->declRef)) + return declRefExpr->declRef; + auto invokeExpr = as<InvokeExpr>(path[path.getCount() - 2]); + if (!invokeExpr) + return declRefExpr->declRef; + if (!invokeExpr->originalFunctionExpr) + return declRefExpr->declRef; + auto originalFuncExpr = invokeExpr->originalFunctionExpr; + if (originalFuncExpr != declRefExpr) + return declRefExpr->declRef; + // If the invoke expression is the same as the decl ref expression, + // it means we are looking at a constructor call. + auto resolvedFuncExpr = as<DeclRefExpr>(invokeExpr->functionExpr); + if (!resolvedFuncExpr) + return declRefExpr->declRef; + auto ctorDecl = as<ConstructorDecl>(resolvedFuncExpr->declRef); + if (ctorDecl) + return ctorDecl; + return declRefExpr->declRef; +} + SlangResult LanguageServer::hover( const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId) @@ -828,7 +859,8 @@ LanguageServerResult<LanguageServerProtocol::Hover> LanguageServerCore::hover( }; if (auto declRefExpr = as<DeclRefExpr>(leafNode)) { - fillDeclRefHoverInfo(declRefExpr->declRef, declRefExpr->name); + auto resolvedDeclRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); + fillDeclRefHoverInfo(resolvedDeclRef, declRefExpr->name); } else if (auto overloadedExpr = as<OverloadedExpr>(leafNode)) { @@ -1004,11 +1036,12 @@ LanguageServerResult<List<LanguageServerProtocol::Location>> LanguageServerCore: { if (declRefExpr->declRef.getDecl()) { + auto declRef = declRefExpr->declRef; + declRef = maybeRedirectToConstructor(declRefExpr, findResult[0].path); auto location = version->linkage->getSourceManager()->getHumaneLoc( - declRefExpr->declRef.getNameLoc().isValid() ? declRefExpr->declRef.getNameLoc() - : declRefExpr->declRef.getLoc(), + declRef.getNameLoc().isValid() ? declRef.getNameLoc() : declRef.getLoc(), SourceLocType::Actual); - auto name = declRefExpr->declRef.getName(); + auto name = declRef.getName(); locations.add(LocationResult{ location, name ? (int)UTF8Util::calcUTF16CharCount(name->text.getUnownedSlice()) : 0}); @@ -1076,6 +1109,14 @@ LanguageServerResult<List<LanguageServerProtocol::Location>> LanguageServerCore: { result.uri = URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri; + } + else if (loc.loc.pathInfo.getName() == "core" || loc.loc.pathInfo.getName() == "glsl") + { + result.uri = StringBuilder() << "slang-synth://" << loc.loc.pathInfo.getName() + << "/" << loc.loc.pathInfo.getName() << ".builtin"; + } + if (result.uri.getLength() != 0) + { doc->oneBasedUTF8LocToZeroBasedUTF16Loc( loc.loc.line, loc.loc.column, @@ -1504,6 +1545,75 @@ SlangResult LanguageServer::signatureHelp( return SLANG_OK; } +// Heuristical cost for determining the best candidate to use as the active signature. +// We will always use the candidate that has the most matched parameters to the current argument +// list. If there are multiple candidates with the same number of matched parameters, we will +// use the one with the least number of unmatched parameters. If there are still multiple +// candidates with the same number of unmatched parameters, we will use the one with the least +// maximum argument conversion cost. +// +struct CallCandidateMatchCost +{ + Index matchedArgCount = 0; + Index excessArgCount = 0; + Index unmatchedParamCount = 0; + ConversionCost maxArgConversionCost = 0; + + bool isBetterThan(const CallCandidateMatchCost& other) const + { + if (excessArgCount < other.excessArgCount) + return true; + else if (excessArgCount > other.excessArgCount) + return false; + if (matchedArgCount > other.matchedArgCount) + return true; + else if (matchedArgCount < other.matchedArgCount) + return false; + + if (unmatchedParamCount < other.unmatchedParamCount) + return true; + else if (unmatchedParamCount > other.unmatchedParamCount) + return false; + return maxArgConversionCost < other.maxArgConversionCost; + } +}; + +// Given a callable decl and an AppExprBase containing the arguments used to call it, +// return the match cost for the candidate. +static CallCandidateMatchCost getCallCandidateMatchCost( + DeclRef<CallableDecl> callableDeclRef, + AppExprBase* appExpr, + SemanticsVisitor& semanticsVisitor, + WorkspaceVersion* version) +{ + CallCandidateMatchCost result; + auto astBuilder = version->linkage->getASTBuilder(); + auto paramList = getMembersOfType<ParamDecl>(astBuilder, callableDeclRef).toArray(); + + for (Index argId = 0; argId < appExpr->arguments.getCount(); argId++) + { + auto arg = appExpr->arguments[argId]; + if (!arg) + continue; + if (!arg->type.type) + continue; + if (argId < paramList.getCount()) + { + auto paramType = getType(version->linkage->getASTBuilder(), paramList[argId]); + ConversionCost argCost = 0; + if (paramType && semanticsVisitor.canCoerce(paramType, arg->type.type, arg, &argCost)) + { + result.matchedArgCount++; + result.maxArgConversionCost = Math::Max(result.maxArgConversionCost, argCost); + } + } + } + result.excessArgCount = + Math::Max((Index)0, (appExpr->argumentDelimeterLocs.getCount() - 1) - paramList.getCount()); + result.unmatchedParamCount = paramList.getCount() - result.matchedArgCount; + return result; +} + LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore::signatureHelp( const LanguageServerProtocol::SignatureHelpParams& args) { @@ -1594,11 +1704,41 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: } SignatureHelp response; + response.activeSignature = 0; + + CallCandidateMatchCost bestCandidateMatchCost; + + // We will use an ad-hoc semantics visitor to check for argument-to-parameter conversions + // and to determine the best candidate signature. + // In the ideal design, this info should be gathered during the normal type checking + // process, but that require a lot of refactoring in the current code base, and may + // risk slowing down type checking for non-language-server use cases since we won't be + // able to do as many early returns. + // So instead we will do a separate ad-hoc checking here to do a best-effort guess + // on the best candidate. + // + DiagnosticSink sink; + SharedSemanticsContext semanticsContext(version->linkage, nullptr, &sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); + auto addDeclRef = [&](DeclRef<Decl> declRef) { if (!declRef.getDecl()) return; + // If we have a better match than the current best, we will update response.activeSignature + // to this signature. + if (auto callableDeclRef = declRef.as<CallableDecl>()) + { + auto matchCost = + getCallCandidateMatchCost(callableDeclRef, appExpr, semanticsVisitor, version); + if (matchCost.isBetterThan(bestCandidateMatchCost)) + { + bestCandidateMatchCost = matchCost; + response.activeSignature = (uint32_t)response.signatures.getCount(); + } + } + SignatureInformation sigInfo; List<Slang::Range<Index>> paramRanges; @@ -1675,13 +1815,14 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) { - if (auto aggDeclRef = as<AggTypeDecl>(declRefExpr->declRef)) + if (auto typeType = as<TypeType>(declRefExpr->type.type)) { // Look for initializers - for (auto member : - getMembersOfType<ConstructorDecl>(version->linkage->getASTBuilder(), aggDeclRef)) + auto ctors = + semanticsVisitor.lookupConstructorsInType(typeType->getType(), declRefExpr->scope); + for (auto ctor : ctors) { - addDeclRef(member); + addDeclRef(ctor.declRef); } } else @@ -1711,7 +1852,6 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: { addFuncType(funcType); } - response.activeSignature = 0; response.activeParameter = 0; for (int i = 1; i < appExpr->argumentDelimeterLocs.getCount(); i++) { @@ -2828,4 +2968,23 @@ SLANG_API SlangResult runLanguageServer(Slang::LanguageServerStartupOptions opti return SLANG_OK; } +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob) +{ + ComPtr<slang::IGlobalSession> globalSession; + slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()); + Slang::Session* session = static_cast<Slang::Session*>(globalSession.get()); + StringBuilder sb; + if (moduleName.startsWith("core")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::Core); + } + else if (moduleName.startsWith("glsl")) + { + session->getBuiltinModuleSource(sb, slang::BuiltinModuleName::GLSL); + } + *blob = StringBlob::moveCreate(sb.produceString()).detach(); + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index 43c3521eb..31b7114c0 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -275,4 +275,7 @@ inline bool _isIdentifierChar(char ch) } SLANG_API SlangResult runLanguageServer(LanguageServerStartupOptions options); +SLANG_API SlangResult +getBuiltinModuleSource(const UnownedStringSlice& moduleName, slang::IBlob** blob); + } // namespace Slang diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp index b1a3dec34..63bf7ed53 100644 --- a/source/slang/slang-workspace-version.cpp +++ b/source/slang/slang-workspace-version.cpp @@ -487,7 +487,10 @@ void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( Index rsLine = inLine - 1; auto bounds = getUTF16Boundaries(inLine); outLine = rsLine; - outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + if (bounds.getCount() != 0) + outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); + else + outCol = inCol - 1; } void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 297a3464f..9bfc2bce9 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -435,6 +435,21 @@ SlangResult Session::compileCoreModule(slang::CompileCoreModuleFlags compileFlag return compileBuiltinModule(slang::BuiltinModuleName::Core, compileFlags); } +void Session::getBuiltinModuleSource(StringBuilder& sb, slang::BuiltinModuleName moduleName) +{ + switch (moduleName) + { + case slang::BuiltinModuleName::Core: + sb << (const char*)getCoreLibraryCode()->getBufferPointer() + << (const char*)getHLSLLibraryCode()->getBufferPointer() + << (const char*)getAutodiffLibraryCode()->getBufferPointer(); + break; + case slang::BuiltinModuleName::GLSL: + sb << (const char*)getGLSLLibraryCode()->getBufferPointer(); + break; + } +} + SlangResult Session::compileBuiltinModule( slang::BuiltinModuleName moduleName, slang::CompileCoreModuleFlags compileFlags) @@ -460,17 +475,7 @@ SlangResult Session::compileBuiltinModule( } StringBuilder moduleSrcBuilder; - switch (moduleName) - { - case slang::BuiltinModuleName::Core: - moduleSrcBuilder << (const char*)getCoreLibraryCode()->getBufferPointer() - << (const char*)getHLSLLibraryCode()->getBufferPointer() - << (const char*)getAutodiffLibraryCode()->getBufferPointer(); - break; - case slang::BuiltinModuleName::GLSL: - moduleSrcBuilder << (const char*)getGLSLLibraryCode()->getBufferPointer(); - break; - } + getBuiltinModuleSource(moduleSrcBuilder, moduleName); // TODO(JS): Could make this return a SlangResult as opposed to exception auto moduleSrcBlob = StringBlob::moveCreate(moduleSrcBuilder.produceString()); |
