diff options
| author | Yong He <yonghe@outlook.com> | 2022-06-13 12:20:35 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-13 12:20:35 -0700 |
| commit | c90c6ab750ab05dd6d337e4f857958b8f3d00153 (patch) | |
| tree | 569085637c5d3de33d7aaec8ab8c0e84be49bfd0 /source | |
| parent | 68d9d87f9385a8c7c29443dcfcbf70434dc889bd (diff) | |
Language Server improvements. (#2269)
* Language Server improvements.
- Improve parser robustness around `attribute_syntax`.
- Exclude instance members in a static query.
- Coloring accessors
- Improved signature help cursor range check.
* Add expected test result.
* Language server: support configuring predefined macros.
* Fix constructor highlighting.
* Improving performance by supporting incremental text change notifications.
* Fix UTF16 positions and highlighting of constructor calls.
* Add completion suggestions for HLSL semantics.
* Fix tests.
* Fix: don't skip static variables in a static query.
* Include literal init expr value in hover text.
* Fix scenarios where completion failed to trigger.
* Fixing language server protocol field initializations.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
28 files changed, 1787 insertions, 674 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.cpp b/source/compiler-core/slang-diagnostic-sink.cpp index 0110b16d7..c934157d2 100644 --- a/source/compiler-core/slang-diagnostic-sink.cpp +++ b/source/compiler-core/slang-diagnostic-sink.cpp @@ -570,7 +570,7 @@ void DiagnosticSink::diagnoseImpl(DiagnosticInfo const& info, const UnownedStrin m_parentSink->diagnoseImpl(info, formattedMessage); } - if (info.severity >= Severity::Fatal) + if (getEffectiveMessageSeverity(info) >= Severity::Fatal) { // TODO: figure out a better policy for aborting compilation SLANG_ABORT_COMPILATION(""); diff --git a/source/compiler-core/slang-language-server-protocol.cpp b/source/compiler-core/slang-language-server-protocol.cpp index 3e8907b9a..713d209dd 100644 --- a/source/compiler-core/slang-language-server-protocol.cpp +++ b/source/compiler-core/slang-language-server-protocol.cpp @@ -181,12 +181,35 @@ const StructRttiInfo DidCloseTextDocumentParams::g_rttiInfo = _makeDidCloseTextD const UnownedStringSlice DidCloseTextDocumentParams::methodName = UnownedStringSlice::fromLiteral("textDocument/didClose"); +static const StructRttiInfo _makeWorkspaceFoldersServerCapabilitiesRtti() +{ + WorkspaceFoldersServerCapabilities obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::WorkspaceFoldersServerCapabilities", nullptr); + builder.addField("supported", &obj.supported); + builder.addField("changeNotifications", &obj.changeNotifications); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo WorkspaceFoldersServerCapabilities::g_rttiInfo = + _makeWorkspaceFoldersServerCapabilitiesRtti(); + +static const StructRttiInfo _makeWorkspaceCapabilitiesRtti() +{ + WorkspaceCapabilities obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::WorkspaceCapabilities", nullptr); + builder.addField("workspaceFolders", &obj.workspaceFolders); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo WorkspaceCapabilities::g_rttiInfo = _makeWorkspaceCapabilitiesRtti(); + static const StructRttiInfo _makeServerCapabilitiesRtti() { ServerCapabilities obj; StructRttiBuilder builder(&obj, "LanguageServerProtocol::ServerCapabilities", nullptr); builder.addField("positionEncoding", &obj.positionEncoding); builder.addField("textDocumentSync", &obj.textDocumentSync); + builder.addField("workspace", &obj.workspace); builder.addField("hoverProvider", &obj.hoverProvider); builder.addField("definitionProvider", &obj.definitionProvider); builder.addField("completionProvider", &obj.completionProvider); @@ -470,6 +493,87 @@ static const StructRttiInfo _makeSignatureHelpRtti() } const StructRttiInfo SignatureHelp::g_rttiInfo = _makeSignatureHelpRtti(); +static const StructRttiInfo _makeDidChangeConfigurationParamsRtti() +{ + DidChangeConfigurationParams obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::DidChangeConfigurationParams", nullptr); + builder.addField("settings", &obj.settings, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo DidChangeConfigurationParams::g_rttiInfo = + _makeDidChangeConfigurationParamsRtti(); +const UnownedStringSlice DidChangeConfigurationParams::methodName = + UnownedStringSlice::fromLiteral("workspace/didChangeConfiguration"); + +static const StructRttiInfo _makeConfigurationItemRtti() +{ + ConfigurationItem obj; + StructRttiBuilder builder( + &obj, "LanguageServerProtocol::ConfigurationItem", nullptr); + builder.addField("section", &obj.section, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo ConfigurationItem::g_rttiInfo = _makeConfigurationItemRtti(); + +static const StructRttiInfo _makeConfigurationParamsRtti() +{ + ConfigurationParams obj; + StructRttiBuilder builder( + &obj, "LanguageServerProtocol::ConfigurationParams", nullptr); + builder.addField("items", &obj.items, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo ConfigurationParams::g_rttiInfo = _makeConfigurationParamsRtti(); +const UnownedStringSlice ConfigurationParams::methodName = + UnownedStringSlice::fromLiteral("workspace/configuration"); + +static const StructRttiInfo _makeRegistrationRtti() +{ + Registration obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::Registration", nullptr); + builder.addField("id", &obj.id, StructRttiInfo::Flag::Optional); + builder.addField("method", &obj.method, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo Registration::g_rttiInfo = _makeRegistrationRtti(); + +static const StructRttiInfo _makeRegistrationParamsRtti() +{ + RegistrationParams obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::RegistrationParams", nullptr); + builder.addField("registrations", &obj.registrations, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo RegistrationParams::g_rttiInfo = _makeRegistrationParamsRtti(); + +static const StructRttiInfo _makeCancelParamsRtti() +{ + CancelParams obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::CancelParams", nullptr); + builder.addField("id", &obj.id, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo CancelParams::g_rttiInfo = _makeCancelParamsRtti(); + +static const StructRttiInfo _makeLogMessageParamsRtti() +{ + LogMessageParams obj; + StructRttiBuilder builder(&obj, "LanguageServerProtocol::LogMessageParams", nullptr); + builder.addField("type", &obj.type, StructRttiInfo::Flag::Optional); + builder.addField("message", &obj.message, StructRttiInfo::Flag::Optional); + builder.ignoreUnknownFields(); + return builder.make(); +} +const StructRttiInfo LogMessageParams::g_rttiInfo = _makeLogMessageParamsRtti(); +const UnownedStringSlice LogMessageParams::methodName = + UnownedStringSlice::fromLiteral("window/logMessage"); + } // namespace LanguageServerProtocol } diff --git a/source/compiler-core/slang-language-server-protocol.h b/source/compiler-core/slang-language-server-protocol.h index e1ac47beb..11446bd0b 100644 --- a/source/compiler-core/slang-language-server-protocol.h +++ b/source/compiler-core/slang-language-server-protocol.h @@ -205,6 +205,33 @@ struct DidCloseTextDocumentParams static const UnownedStringSlice methodName; }; +struct WorkspaceFoldersServerCapabilities +{ + /** + * The server has support for workspace folders + */ + bool supported = false; + + /** + * Whether the server wants to receive workspace folder + * change notifications. + * + * If a string is provided, the string is treated as an ID + * under which the notification is registered on the client + * side. The ID can be used to unregister for these events + * using the `client/unregisterCapability` request. + */ + bool changeNotifications = false; + + static const StructRttiInfo g_rttiInfo; +}; + +struct WorkspaceCapabilities +{ + WorkspaceFoldersServerCapabilities workspaceFolders; + static const StructRttiInfo g_rttiInfo; +}; + struct ServerCapabilities { String positionEncoding; @@ -214,6 +241,7 @@ struct ServerCapabilities CompletionOptions completionProvider; SemanticTokensOptions semanticTokensProvider; SignatureHelpOptions signatureHelpProvider; + WorkspaceCapabilities workspace; static const StructRttiInfo g_rttiInfo; }; @@ -245,11 +273,13 @@ struct InitializeResult static const StructRttiInfo g_rttiInfo; }; -struct ShutdownParams { +struct ShutdownParams +{ static const UnownedStringSlice methodName; }; -struct ExitParams { +struct ExitParams +{ static const UnownedStringSlice methodName; }; @@ -630,5 +660,85 @@ struct SignatureHelp }; +struct DidChangeConfigurationParams +{ + /** + * The actual changed settings + */ + JSONValue settings = JSONValue::makeInvalid(); + + static const StructRttiInfo g_rttiInfo; + + static const UnownedStringSlice methodName; +}; + +struct ConfigurationItem +{ + /** + * The configuration section asked for. + */ + String section; + + static const StructRttiInfo g_rttiInfo; +}; + +struct ConfigurationParams +{ + List<ConfigurationItem> items; + + static const StructRttiInfo g_rttiInfo; + + static const UnownedStringSlice methodName; +}; + +struct Registration +{ + /** + * The id used to register the request. The id can be used to deregister + * the request again. + */ + String id; + + /** + * The method / capability to register for. + */ + String method; + + static const StructRttiInfo g_rttiInfo; +}; + +struct RegistrationParams +{ + List<Registration> registrations; + + static const StructRttiInfo g_rttiInfo; +}; + +struct CancelParams +{ + /** + * The request id to cancel. + */ + int64_t id = 0; + + static const StructRttiInfo g_rttiInfo; +}; + +struct LogMessageParams +{ + /** + * The message type. See {@link MessageType} + */ + int type = 0; + + /** + * The actual message + */ + String message; + + static const StructRttiInfo g_rttiInfo; + static const UnownedStringSlice methodName; +}; + } // namespace LanguageServerProtocol } // namespace Slang diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 433d42d4a..65eaea4f1 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -16,6 +16,10 @@ class DeclGroup: public DeclBase List<Decl*> decls; }; +class UnresolvedDecl : public Decl +{ + SLANG_AST_CLASS(UnresolvedDecl) +}; // A "container" decl is a parent to other declarations class ContainerDecl: public Decl diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index eb918f0b9..e3dea7df1 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -517,7 +517,10 @@ namespace Slang return cf(astBuilder); } - SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const { return classInfo->isSubClassOf(*super.classInfo); } + SLANG_FORCE_INLINE bool isSubClassOfImpl(SyntaxClassBase const& super) const + { + return classInfo ? classInfo->isSubClassOf(*super.classInfo) : false; + } ReflectClassInfo const* classInfo = nullptr; }; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 6fddbc453..bc037b0c2 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -255,6 +255,12 @@ namespace Slang return m_module; } + bool isInLanguageServer() + { + if (m_linkage) + return m_linkage->isInLanguageServer(); + return false; + } /// Get the list of extension declarations that appear to apply to `decl` in this context List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); @@ -1685,12 +1691,13 @@ namespace Slang // deal with this cases here, even if they are no-ops. // - #define CASE(NAME) \ - Expr* visit##NAME(NAME* expr) \ - { \ - SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, \ - "should not appear in input syntax"); \ - return expr; \ + #define CASE(NAME) \ + Expr* visit##NAME(NAME* expr) \ + { \ + if (!getShared()->isInLanguageServer()) \ + SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, "should not appear in input syntax"); \ + expr->type = m_astBuilder->getErrorType(); \ + return expr; \ } CASE(DerefExpr) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 176609106..429479a38 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -594,7 +594,7 @@ namespace Slang return uncheckedAttr; } - if(!attrDecl->syntaxClass.isSubClassOf<Attribute>()) + if (!attrDecl->syntaxClass.isSubClassOf<Attribute>()) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute declaration does not reference an attribute class"); return uncheckedAttr; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 148b0205b..280bdf688 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1671,6 +1671,14 @@ namespace Slang /// lookup additional loaded modules. typedef Dictionary<Name*, Module*> LoadedModuleDictionary; + class Linkage; + class IModuleCache + { + public: + virtual RefPtr<Module> tryLoadModule(Linkage* linkage, String filePath) = 0; + virtual void storeModule(Linkage* linkage, String filePath, RefPtr<Module> module) = 0; + }; + /// A context for loading and re-using code modules. class Linkage : public RefObject, public slang::ISession { @@ -1791,6 +1799,8 @@ namespace Slang TypeCheckingCache* m_typeCheckingCache = nullptr; + void setModuleCache(IModuleCache* cache) { m_moduleCache = cache; } + // Modules that have been dynamically loaded via `import` // // This is a list of unique modules loaded, in the order they were encountered. @@ -1925,7 +1935,7 @@ namespace Slang RefPtr<Session> m_retainedSession; - + IModuleCache* m_moduleCache = nullptr; /// Tracks state of modules currently being loaded. /// diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 768c8ef2d..3a35e699b 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -1,23 +1,12 @@ #include "slang-language-server-ast-lookup.h" #include "slang-visitor.h" +#include "slang-workspace-version.h" namespace Slang { -struct Loc -{ - Int line; - Int col; - bool operator<(const Loc& other) - { - return line < other.line || line == other.line && col < other.col; - } - bool operator<=(const Loc& other) - { - return line < other.line || line == other.line && col <= other.col; - } -}; struct ASTLookupContext { + DocumentVersion* doc; SourceManager* sourceManager; List<SyntaxNode*> nodePath; ASTLookupType findType; @@ -29,12 +18,18 @@ struct ASTLookupContext Loc getLoc(SourceLoc loc, String* outFileName) { - auto humaneLoc = sourceManager->getHumaneLoc(loc, SourceLocType::Actual); - if (outFileName) - *outFileName = humaneLoc.pathInfo.foundPath; - return Loc{humaneLoc.line, humaneLoc.column}; + return Loc::fromSourceLoc(sourceManager, loc, outFileName); } }; + +Loc Loc::fromSourceLoc(SourceManager* manager, SourceLoc loc, String* outFileName) +{ + auto humaneLoc = manager->getHumaneLoc(loc, SourceLocType::Actual); + if (outFileName) + *outFileName = humaneLoc.pathInfo.foundPath; + return Loc{humaneLoc.line, humaneLoc.column}; +} + struct PushNode { ASTLookupContext* context; @@ -46,12 +41,21 @@ struct PushNode ~PushNode() { if (context) context->nodePath.removeLast(); } }; -static Index _getDeclNameLength(Name* name) +static Index _getDeclNameLength(Name* name, Decl* optionalDecl = nullptr) { if (!name) return 0; if (name->text.startsWith("$")) + { + if (auto ctorDecl = as<ConstructorDecl>(optionalDecl)) + { + if (ctorDecl->parentDecl && optionalDecl->parentDecl->getName()) + { + return optionalDecl->parentDecl->getName()->text.getLength(); + } + } return 0; + } // HACK: our __subscript functions currently have a name "operator[]". // Since this isn't the name that actually appears in user's code, // we need to shorten its reported length to 1 for now. @@ -168,17 +172,31 @@ public: bool visitVarExpr(VarExpr* expr) { - if (expr->name && expr->declRef.getDecl() && - _isLocInRange(context, expr->loc, _getDeclNameLength(expr->name))) + if (expr->name && expr->declRef.getDecl()) { if (expr->declRef.getDecl()->hasModifier<ImplicitConversionModifier>()) return false; - ASTLookupResult result; - result.path = context->nodePath; - result.path.add(expr); - context->results.add(result); - return true; + Int declLength = 0; + if (auto ctorDecl = as<ConstructorDecl>(expr->declRef.getDecl())) + { + auto humaneLoc = context->sourceManager->getHumaneLoc(expr->loc, SourceLocType::Actual); + declLength = context->doc->getTokenLength(humaneLoc.line, humaneLoc.column); + } + else + { + declLength = _getDeclNameLength(expr->name, expr->declRef.getDecl()); + } + if (_isLocInRange( + context, expr->loc, declLength)) + { + ASTLookupResult result; + result.path = context->nodePath; + result.path.add(expr); + context->results.add(result); + return true; + } } + return dispatchIfNotNull(expr->originalExpr); } @@ -524,6 +542,21 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) if (visitor.dispatchIfNotNull(typedefDecl->type.exp)) return true; } + for (auto modifier : decl->modifiers) + { + if (auto hlslSemantic = as<HLSLSemantic>(modifier)) + { + if (_isLocInRange( + &context, hlslSemantic->loc, hlslSemantic->name.getContentLength())) + { + ASTLookupResult result; + result.path = context.nodePath; + result.path.add(hlslSemantic); + context.results.add(result); + return true; + } + } + } if (auto container = as<ContainerDecl>(node)) { bool shouldInspectChildren = true; @@ -550,7 +583,7 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) } List<ASTLookupResult> findASTNodesAt( - SourceManager* sourceManager, ModuleDecl* moduleDecl, ASTLookupType findType, UnownedStringSlice fileName, Int line, Int col) + DocumentVersion* doc, SourceManager* sourceManager, ModuleDecl* moduleDecl, ASTLookupType findType, UnownedStringSlice fileName, Int line, Int col) { ASTLookupContext context; context.sourceManager = sourceManager; @@ -559,6 +592,7 @@ List<ASTLookupResult> findASTNodesAt( context.cursorLoc = Loc{line, col}; context.findType = findType; context.sourceFileName = fileName; + context.doc = doc; _findAstNodeImpl(context, moduleDecl); return context.results; } diff --git a/source/slang/slang-language-server-ast-lookup.h b/source/slang/slang-language-server-ast-lookup.h index 9fad5e8bd..93e94a869 100644 --- a/source/slang/slang-language-server-ast-lookup.h +++ b/source/slang/slang-language-server-ast-lookup.h @@ -1,6 +1,7 @@ #pragma once #include "slang-ast-all.h" +#include "slang-workspace-version.h" namespace Slang { @@ -13,7 +14,23 @@ enum class ASTLookupType Decl, Invoke, }; + +struct Loc +{ + Int line; + Int col; + bool operator<(const Loc& other) + { + return line < other.line || line == other.line && col < other.col; + } + bool operator<=(const Loc& other) + { + return line < other.line || line == other.line && col <= other.col; + } + static Loc fromSourceLoc(SourceManager* manager, SourceLoc loc, String* outFileName = nullptr); +}; List<ASTLookupResult> findASTNodesAt( + DocumentVersion* doc, SourceManager* sourceManager, ModuleDecl* moduleDecl, ASTLookupType findType, diff --git a/source/slang/slang-language-server-collect-member.cpp b/source/slang/slang-language-server-collect-member.cpp deleted file mode 100644 index 73539b9d9..000000000 --- a/source/slang/slang-language-server-collect-member.cpp +++ /dev/null @@ -1,156 +0,0 @@ -// slang-language-server-collect-member.cpp - -// This file implements the logic to collect all members from a parsed type.] -// The flow is mostly the same as `lookupMemberInType`, but instead of looking for a specific name, -// we collect all members we see. - -#include "slang-language-server-collect-member.h" - -namespace Slang -{ -void collectMembersInType(MemberCollectingContext* context, Type* type) -{ - if (auto pointerLikeType = as<PointerLikeType>(type)) - { - collectMembersInType(context, pointerLikeType->elementType); - return; - } - - if (auto declRefType = as<DeclRefType>(type)) - { - auto declRef = declRefType->declRef; - - collectMembersInTypeDeclImpl( - context, - declRef); - } - else if (auto nsType = as<NamespaceType>(type)) - { - auto declRef = nsType->declRef; - - collectMembersInTypeDeclImpl(context, declRef); - } - else if (auto extractExistentialType = as<ExtractExistentialType>(type)) - { - // We want lookup to be performed on the underlying interface type of the existential, - // but we need to have a this-type substitution applied to ensure that the result of - // lookup will have a comparable substitution applied (allowing things like associated - // types, etc. used in the signature of a method to resolve correctly). - // - auto interfaceDeclRef = extractExistentialType->getSpecializedInterfaceDeclRef(); - collectMembersInTypeDeclImpl(context, interfaceDeclRef); - } - else if (auto thisType = as<ThisType>(type)) - { - auto interfaceType = DeclRefType::create(context->astBuilder, thisType->interfaceDeclRef); - collectMembersInType(context, interfaceType); - } - else if (auto andType = as<AndType>(type)) - { - auto leftType = andType->left; - auto rightType = andType->right; - collectMembersInType(context, leftType); - collectMembersInType(context, rightType); - } -} - -void collectMembersInTypeDeclImpl( - MemberCollectingContext* context, - DeclRef<Decl> declRef) -{ - if (declRef.getDecl()->checkState.getState() < DeclCheckState::ReadyForLookup) - return; - - if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>()) - { - // If the type we are doing lookup in is a generic type parameter, - // then the members it provides can only be discovered by looking - // at the constraints that are placed on that type. - auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); - assert(genericDeclRef); - - for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef)) - { - if (constraintDeclRef.decl->checkState.getState() < DeclCheckState::ReadyForLookup) - { - continue; - } - - collectMembersInType( - context, - getSup(context->astBuilder, constraintDeclRef)); - } - } - else if (declRef.as<AssocTypeDecl>() || declRef.as<GlobalGenericParamDecl>()) - { - for (auto constraintDeclRef : - getMembersOfType<TypeConstraintDecl>(declRef.as<ContainerDecl>())) - { - if (constraintDeclRef.decl->checkState.getState() < DeclCheckState::ReadyForLookup) - { - continue; - } - collectMembersInType(context, getSup(context->astBuilder, constraintDeclRef)); - } - } - else if (auto namespaceDecl = declRef.as<NamespaceDecl>()) - { - for (auto member : namespaceDecl.getDecl()->members) - { - if (member->getName()) - { - context->members.add(member); - } - } - } - else if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>()) - { - // In this case we are peforming lookup in the context of an aggregate - // type or an `extension`, so the first thing to do is to look for - // matching members declared directly in the body of the type/`extension`. - // - for (auto member : aggTypeDeclBaseRef.getDecl()->members) - { - if (member->getName()) - { - context->members.add(member); - } - } - - if (auto aggTypeDeclRef = aggTypeDeclBaseRef.as<AggTypeDecl>()) - { - auto extensions = - context->semanticsContext.getCandidateExtensionsForTypeDecl(aggTypeDeclRef); - for (auto extDecl : extensions) - { - // TODO: check if the extension can be applied before including its members. - // TODO: eventually we need to insert a breadcrumb here so that - // the constructed result can somehow indicate that a member - // was found through an extension. - // - collectMembersInTypeDeclImpl( - context, - DeclRef<Decl>(extDecl, nullptr)); - } - } - - // For both aggregate types and their `extension`s, we want lookup to follow - // through the declared inheritance relationships on each declaration. - // - for (auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclBaseRef)) - { - // Some things that are syntactically `InheritanceDecl`s don't actually - // represent a subtype/supertype relationship, and thus we shouldn't - // include members from the base type when doing lookup in the - // derived type. - // - if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) - continue; - - collectMembersInType( - context, getSup(context->astBuilder, inheritanceDeclRef)); - } - } -} - -} // namespace Slang diff --git a/source/slang/slang-language-server-collect-member.h b/source/slang/slang-language-server-collect-member.h deleted file mode 100644 index bb48c4d5b..000000000 --- a/source/slang/slang-language-server-collect-member.h +++ /dev/null @@ -1,25 +0,0 @@ -// slang-language-server-collect-member.h -#pragma once - -#include "slang-ast-all.h" -#include "slang-syntax.h" -#include "slang-check-impl.h" - -namespace Slang -{ - -struct MemberCollectingContext -{ - ASTBuilder* astBuilder; - List<Decl*> members; - SharedSemanticsContext semanticsContext; - MemberCollectingContext(Linkage* linkage, Module* module, DiagnosticSink* sink) - : semanticsContext(linkage, module, sink) - {} -}; - -void collectMembersInTypeDeclImpl(MemberCollectingContext* context, DeclRef<Decl> declRef); - -void collectMembersInType(MemberCollectingContext* context, Type* type); - -} // namespace Slang diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp new file mode 100644 index 000000000..c47e951fb --- /dev/null +++ b/source/slang/slang-language-server-completion.cpp @@ -0,0 +1,472 @@ +// slang-language-server-completion.cpp + +#include "slang-language-server-completion.h" +#include "slang-language-server-ast-lookup.h" +#include "slang-language-server.h" + +#include "slang-ast-all.h" +#include "slang-syntax.h" +#include "slang-check-impl.h" + +namespace Slang +{ + +static bool _isIdentifierChar(char ch) +{ + return ch >= '0' && ch <= '9' || ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch == '_'; +} + +static bool _isWhitespaceChar(char ch) { return ch == ' ' || ch == '\r' || ch == '\n' || ch == '\t'; } + +static const char* hlslSemanticNames[] = { + "register", + "packoffset", + "read", + "write", + "SV_ClipDistance", + "SV_CullDistance", + "SV_Coverage", + "SV_Depth", + "SV_DepthGreaterEqual", + "SV_DepthLessEqual", + "SV_DispatchThreadID", + "SV_DomainLocation", + "SV_GroupID", + "SV_GroupIndex", + "SV_GroupThreadID", + "SV_GSInstanceID", + "SV_InnerCoverage", + "SV_InsideTessFactor", + "SV_InstanceID", + "SV_IsFrontFace", + "SV_OutputControlPointID", + "SV_Position", + "SV_PrimitiveID", + "SV_RenderTargetArrayIndex", + "SV_SampleIndex", + "SV_StencilRef", + "SV_Target", + "SV_TessFactor", + "SV_VertexID", + "SV_ViewportArrayIndex", + "SV_ShadingRate", +}; + +SlangResult CompletionContext::tryCompleteHLSLSemantic() +{ + auto findResult = findASTNodesAt( + doc, + version->linkage->getSourceManager(), + parsedModule->getModuleDecl(), + ASTLookupType::Decl, + canonicalPath, + line, + col); + if (findResult.getCount() == 1 && findResult[0].path.getCount() != 0) + { + if (auto semantic = as<HLSLSemantic>(findResult[0].path.getLast())) + { + List<LanguageServerProtocol::CompletionItem> items; + for (auto name : hlslSemanticNames) + { + LanguageServerProtocol::CompletionItem item; + item.label = name; + item.kind = LanguageServerProtocol::kCompletionItemKindKeyword; + for (auto ch : getCommitChars()) + item.commitCharacters.add(ch); + items.add(item); + } + server->m_connection->sendResult(&items, responseId); + return SLANG_OK; + } + } + return SLANG_FAIL; +} + +SlangResult CompletionContext::tryCompleteMember() +{ + // Scan backward until we locate a '.' or ':'. + if (cursorOffset > 0) + cursorOffset--; + while (cursorOffset > 0 && _isWhitespaceChar(doc->getText()[cursorOffset])) + { + cursorOffset--; + } + while (cursorOffset > 0 && _isIdentifierChar(doc->getText()[cursorOffset])) + { + cursorOffset--; + } + while (cursorOffset > 0 && _isWhitespaceChar(doc->getText()[cursorOffset])) + { + cursorOffset--; + } + if (cursorOffset > 0 && doc->getText()[cursorOffset] == ':') + cursorOffset--; + if (cursorOffset <= 0 || + (doc->getText()[cursorOffset] != '.' && doc->getText()[cursorOffset] != ':')) + { + return SLANG_FAIL; + } + doc->offsetToLineCol(cursorOffset, line, col); + auto findResult = findASTNodesAt( + doc, + version->linkage->getSourceManager(), + parsedModule->getModuleDecl(), + ASTLookupType::Decl, + canonicalPath, + line, + col); + if (findResult.getCount() != 1) + { + return SLANG_FAIL; + } + if (findResult[0].path.getCount() == 0) + { + return SLANG_FAIL; + } + Expr* baseExpr = nullptr; + if (auto memberExpr = as<MemberExpr>(findResult[0].path.getLast())) + { + baseExpr = memberExpr->baseExpression; + } + else if (auto staticMemberExpr = as<StaticMemberExpr>(findResult[0].path.getLast())) + { + baseExpr = staticMemberExpr->baseExpression; + } + else if (auto swizzleExpr = as<SwizzleExpr>(findResult[0].path.getLast())) + { + baseExpr = swizzleExpr->base; + } + else if (auto matSwizzleExpr = as<MatrixSwizzleExpr>(findResult[0].path.getLast())) + { + baseExpr = matSwizzleExpr->base; + } + if (!baseExpr || !baseExpr->type.type || + baseExpr->type.type->equals(version->linkage->getASTBuilder()->getErrorType())) + { + return SLANG_FAIL; + } + + List<LanguageServerProtocol::CompletionItem> items = collectMembers(baseExpr); + server->m_connection->sendResult(&items, responseId); + return SLANG_OK; +} + +// The following collectMember* functions implement the logic to collect all members from a parsed type.] +// The flow is mostly the same as `lookupMemberInType`, but instead of looking for a specific name, +// we collect all members we see. + +struct MemberCollectingContext +{ + ASTBuilder* astBuilder; + List<Decl*> members; + bool includeInstanceMembers = true; + SharedSemanticsContext semanticsContext; + MemberCollectingContext(Linkage* linkage, Module* module, DiagnosticSink* sink) + : semanticsContext(linkage, module, sink) + {} +}; + +void collectMembersInTypeDeclImpl(MemberCollectingContext* context, DeclRef<Decl> declRef); + +void collectMembersInType(MemberCollectingContext* context, Type* type); + +void collectMembersInType(MemberCollectingContext* context, Type* type) +{ + if (auto pointerLikeType = as<PointerLikeType>(type)) + { + collectMembersInType(context, pointerLikeType->elementType); + return; + } + + if (auto declRefType = as<DeclRefType>(type)) + { + auto declRef = declRefType->declRef; + + collectMembersInTypeDeclImpl( + context, + declRef); + } + else if (auto nsType = as<NamespaceType>(type)) + { + auto declRef = nsType->declRef; + + collectMembersInTypeDeclImpl(context, declRef); + } + else if (auto extractExistentialType = as<ExtractExistentialType>(type)) + { + // We want lookup to be performed on the underlying interface type of the existential, + // but we need to have a this-type substitution applied to ensure that the result of + // lookup will have a comparable substitution applied (allowing things like associated + // types, etc. used in the signature of a method to resolve correctly). + // + auto interfaceDeclRef = extractExistentialType->getSpecializedInterfaceDeclRef(); + collectMembersInTypeDeclImpl(context, interfaceDeclRef); + } + else if (auto thisType = as<ThisType>(type)) + { + auto interfaceType = DeclRefType::create(context->astBuilder, thisType->interfaceDeclRef); + collectMembersInType(context, interfaceType); + } + else if (auto andType = as<AndType>(type)) + { + auto leftType = andType->left; + auto rightType = andType->right; + collectMembersInType(context, leftType); + collectMembersInType(context, rightType); + } +} + +void collectMembersInTypeDeclImpl( + MemberCollectingContext* context, + DeclRef<Decl> declRef) +{ + if (declRef.getDecl()->checkState.getState() < DeclCheckState::ReadyForLookup) + return; + + if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>()) + { + // If the type we are doing lookup in is a generic type parameter, + // then the members it provides can only be discovered by looking + // at the constraints that are placed on that type. + auto genericDeclRef = genericTypeParamDeclRef.getParent().as<GenericDecl>(); + assert(genericDeclRef); + + for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef)) + { + if (constraintDeclRef.decl->checkState.getState() < DeclCheckState::ReadyForLookup) + { + continue; + } + + collectMembersInType( + context, + getSup(context->astBuilder, constraintDeclRef)); + } + } + else if (declRef.as<AssocTypeDecl>() || declRef.as<GlobalGenericParamDecl>()) + { + for (auto constraintDeclRef : + getMembersOfType<TypeConstraintDecl>(declRef.as<ContainerDecl>())) + { + if (constraintDeclRef.decl->checkState.getState() < DeclCheckState::ReadyForLookup) + { + continue; + } + collectMembersInType(context, getSup(context->astBuilder, constraintDeclRef)); + } + } + else if (auto namespaceDecl = declRef.as<NamespaceDecl>()) + { + for (auto member : namespaceDecl.getDecl()->members) + { + if (member->getName()) + { + context->members.add(member); + } + } + } + else if (auto aggTypeDeclBaseRef = declRef.as<AggTypeDeclBase>()) + { + // In this case we are peforming lookup in the context of an aggregate + // type or an `extension`, so the first thing to do is to look for + // matching members declared directly in the body of the type/`extension`. + // + for (auto member : aggTypeDeclBaseRef.getDecl()->members) + { + if (member->getName()) + { + if (!context->includeInstanceMembers) + { + // Skip non-static members. + if (as<PropertyDecl>(member)) + continue; + if (as<SubscriptDecl>(member)) + continue; + if (as<VarDeclBase>(member) || as<FuncDecl>(member)) + { + if (!member->findModifier<HLSLStaticModifier>()) + { + continue; + } + } + } + context->members.add(member); + } + } + + if (auto aggTypeDeclRef = aggTypeDeclBaseRef.as<AggTypeDecl>()) + { + auto extensions = + context->semanticsContext.getCandidateExtensionsForTypeDecl(aggTypeDeclRef); + for (auto extDecl : extensions) + { + // TODO: check if the extension can be applied before including its members. + // TODO: eventually we need to insert a breadcrumb here so that + // the constructed result can somehow indicate that a member + // was found through an extension. + // + collectMembersInTypeDeclImpl( + context, + DeclRef<Decl>(extDecl, nullptr)); + } + } + + // For both aggregate types and their `extension`s, we want lookup to follow + // through the declared inheritance relationships on each declaration. + // + for (auto inheritanceDeclRef : getMembersOfType<InheritanceDecl>(aggTypeDeclBaseRef)) + { + // Some things that are syntactically `InheritanceDecl`s don't actually + // represent a subtype/supertype relationship, and thus we shouldn't + // include members from the base type when doing lookup in the + // derived type. + // + if (inheritanceDeclRef.getDecl()->hasModifier<IgnoreForLookupModifier>()) + continue; + + collectMembersInType( + context, getSup(context->astBuilder, inheritanceDeclRef)); + } + } +} + +List<LanguageServerProtocol::CompletionItem> CompletionContext::collectMembers(Expr* baseExpr) +{ + List<LanguageServerProtocol::CompletionItem> result; + auto linkage = version->linkage; + Type* type = baseExpr->type.type; + bool isInstance = true; + if (auto typeType = as<TypeType>(type)) + { + type = typeType->type; + isInstance = false; + } + version->currentCompletionItems.clear(); + if (type) + { + if (isInstance && as<ArithmeticExpressionType>(type)) + { + // Hard code members for vector and matrix types. + result.clear(); + version->currentCompletionItems.clear(); + int elementCount = 0; + Type* elementType = nullptr; + const char* memberNames[4] = {"x", "y", "z", "w"}; + if (auto vectorType = as<VectorExpressionType>(type)) + { + if (auto elementCountVal = as<ConstantIntVal>(vectorType->elementCount)) + { + elementCount = (int)elementCountVal->value; + elementType = vectorType->elementType; + } + } + else if (auto matrixType = as<MatrixExpressionType>(type)) + { + if (auto elementCountVal = as<ConstantIntVal>(matrixType->getRowCount())) + { + elementCount = (int)elementCountVal->value; + elementType = matrixType->getRowType(); + } + } + String typeStr; + if (elementType) + typeStr = elementType->toString(); + elementCount = Math::Min(elementCount, 4); + for (int i = 0; i < elementCount; i++) + { + LanguageServerProtocol::CompletionItem item; + item.data = 0; + item.detail = typeStr; + item.kind = LanguageServerProtocol::kCompletionItemKindVariable; + item.label = memberNames[i]; + result.add(item); + } + } + else + { + DiagnosticSink sink; + MemberCollectingContext context(linkage, parsedModule, &sink); + context.astBuilder = linkage->getASTBuilder(); + context.includeInstanceMembers = isInstance; + collectMembersInType(&context, type); + HashSet<String> deduplicateSet; + for (auto member : context.members) + { + LanguageServerProtocol::CompletionItem item; + item.label = member->getName()->text; + item.kind = 0; + if (as<TypeConstraintDecl>(member)) + { + continue; + } + if (as<ConstructorDecl>(member)) + { + continue; + } + if (as<SubscriptDecl>(member)) + { + continue; + } + + if (item.label.startsWith("$")) + continue; + if (!deduplicateSet.Add(item.label)) + continue; + + if (as<StructDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindStruct; + } + else if (as<ClassDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindClass; + } + else if (as<InterfaceDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindInterface; + } + else if (as<SimpleTypeDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindClass; + } + else if (as<PropertyDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindProperty; + } + else if (as<EnumDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindEnum; + } + else if (as<VarDeclBase>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindVariable; + } + else if (as<EnumCaseDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindEnumMember; + } + else if (as<CallableDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindMethod; + } + else if (as<AssocTypeDecl>(member)) + { + item.kind = LanguageServerProtocol::kCompletionItemKindClass; + } + item.data = String(version->currentCompletionItems.getCount()); + result.add(item); + version->currentCompletionItems.add(member); + } + } + + for (auto& item : result) + { + for (auto ch : getCommitChars()) + item.commitCharacters.add(ch); + } + } + return result; +} + +} // namespace Slang diff --git a/source/slang/slang-language-server-completion.h b/source/slang/slang-language-server-completion.h new file mode 100644 index 000000000..73aad6cd1 --- /dev/null +++ b/source/slang/slang-language-server-completion.h @@ -0,0 +1,27 @@ +// slang-language-server-completion.h +#pragma once + +#include "slang-workspace-version.h" + +namespace Slang +{ +class LanguageServer; + +struct CompletionContext +{ + LanguageServer* server; + Index cursorOffset; + WorkspaceVersion* version; + DocumentVersion* doc; + Module* parsedModule; + JSONValue responseId; + UnownedStringSlice canonicalPath; + Int line; + Int col; + + SlangResult tryCompleteMember(); + SlangResult tryCompleteHLSLSemantic(); + List<LanguageServerProtocol::CompletionItem> collectMembers(Expr* baseExpr); +}; + +} // namespace Slang diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index 34c40d8fd..806fa69ba 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -1,6 +1,7 @@ #include "slang-language-server-semantic-tokens.h" #include "slang-visitor.h" #include "slang-ast-support-types.h" +#include "../core/slang-char-util.h" #include <algorithm> namespace Slang @@ -412,7 +413,7 @@ void iterateAST(UnownedStringSlice fileName, SourceManager* manager, SyntaxNode* } const char* kSemanticTokenTypes[] = { - "type", "enumMember", "variable", "parameter", "function", "property", "namespace"}; + "type", "enumMember", "variable", "parameter", "function", "property", "namespace", "keyword" }; static_assert(SLANG_COUNT_OF(kSemanticTokenTypes) == (int)SemanticTokenType::NormalText, "kSemanticTokenTypes must match SemanticTokenType"); @@ -420,15 +421,15 @@ SemanticToken _createSemanticToken(SourceManager* manager, SourceLoc loc, Name* { SemanticToken token; auto humaneLoc = manager->getHumaneLoc(loc, SourceLocType::Actual); - token.line = (int)(humaneLoc.line - 1); - token.col = (int)(humaneLoc.column - 1); + token.line = (int)(humaneLoc.line); + token.col = (int)(humaneLoc.column); token.length = name ? (int)(name->text.getLength()) : 0; token.type = SemanticTokenType::NormalText; return token; } -List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedStringSlice fileName) +List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedStringSlice fileName, DocumentVersion* doc) { auto manager = linkage->getSourceManager(); @@ -439,7 +440,6 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS token.type != SemanticTokenType::NormalText) result.add(token); }; - iterateAST( fileName, manager, @@ -465,6 +465,11 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS return; token.type = SemanticTokenType::Type; } + else if (as<ConstructorDecl>(target)) + { + token.type = SemanticTokenType::Type; + token.length = doc->getTokenLength(token.line, token.col); + } else if (as<SimpleTypeDecl>(target)) { token.type = SemanticTokenType::Type; @@ -479,6 +484,11 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS } else if (as<VarDecl>(target)) { + if (as<MemberExpr>(declRef->originalExpr) || + as<StaticMemberExpr>(declRef->originalExpr)) + { + return; + } token.type = SemanticTokenType::Variable; } else if (as<FunctionDeclBase>(target)) @@ -503,6 +513,13 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS } } + else if (auto accessorDecl = as<AccessorDecl>(node)) + { + SemanticToken token = _createSemanticToken( + manager, accessorDecl->loc, accessorDecl->getName()); + token.type = SemanticTokenType::Keyword; + maybeInsertToken(token); + } else if (auto typeDecl = as<SimpleTypeDecl>(node)) { if (typeDecl->getName()) diff --git a/source/slang/slang-language-server-semantic-tokens.h b/source/slang/slang-language-server-semantic-tokens.h index 925972200..c7cd8b63a 100644 --- a/source/slang/slang-language-server-semantic-tokens.h +++ b/source/slang/slang-language-server-semantic-tokens.h @@ -5,12 +5,13 @@ #include "slang-ast-all.h" #include "slang-syntax.h" #include "slang-compiler.h" +#include "slang-workspace-version.h" namespace Slang { enum class SemanticTokenType { - Type, EnumMember, Variable, Parameter, Function, Property, Namespace, NormalText + Type, EnumMember, Variable, Parameter, Function, Property, Namespace, Keyword, NormalText }; extern const char* kSemanticTokenTypes[(int)SemanticTokenType::NormalText]; @@ -30,7 +31,7 @@ struct SemanticToken } }; List<SemanticToken> getSemanticTokens( - Linkage* linkage, Module* module, UnownedStringSlice fileName); + Linkage* linkage, Module* module, UnownedStringSlice fileName, DocumentVersion* doc); List<uint32_t> getEncodedTokens(List<SemanticToken>& tokens); } // namespace Slang diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index d81e7dd8b..9f7329e2c 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -12,61 +12,29 @@ #include "../core/slang-secure-crt.h" #include "../core/slang-range.h" #include "../../slang-com-helper.h" +#include "../compiler-core/slang-json-native.h" #include "../compiler-core/slang-json-rpc-connection.h" #include "../compiler-core/slang-language-server-protocol.h" #include "slang-language-server.h" #include "slang-workspace-version.h" #include "slang-language-server-ast-lookup.h" -#include "slang-language-server-collect-member.h" +#include "slang-language-server-completion.h" #include "slang-language-server-semantic-tokens.h" #include "slang-ast-print.h" #include "slang-doc-markdown-writer.h" +#include "../../tools/platform/performance-counter.h" namespace Slang { using namespace LanguageServerProtocol; -class LanguageServer +ArrayView<const char*> getCommitChars() { -public: - RefPtr<JSONRPCConnection> m_connection; - ComPtr<slang::IGlobalSession> m_session; - RefPtr<Workspace> m_workspace; - Dictionary<String, String> m_lastPublishedDiagnostics; - time_t m_lastDiagnosticUpdateTime = 0; - - bool m_quit = false; - List<LanguageServerProtocol::WorkspaceFolder> m_workspaceFolders; - - SlangResult init(const LanguageServerProtocol::InitializeParams& args); - SlangResult execute(); - void update(); - SlangResult didOpenTextDocument(const LanguageServerProtocol::DidOpenTextDocumentParams& args); - SlangResult didCloseTextDocument( - const LanguageServerProtocol::DidCloseTextDocumentParams& args); - SlangResult didChangeTextDocument( - const LanguageServerProtocol::DidChangeTextDocumentParams& args); - SlangResult hover(const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId); - SlangResult gotoDefinition(const LanguageServerProtocol::DefinitionParams& args, const JSONValue& responseId); - SlangResult completion( - const LanguageServerProtocol::CompletionParams& args, const JSONValue& responseId); - SlangResult completionResolve( - const LanguageServerProtocol::CompletionItem& args, const JSONValue& responseId); - SlangResult semanticTokens( - const LanguageServerProtocol::SemanticTokensParams& args, const JSONValue& responseId); - SlangResult signatureHelp( - const LanguageServerProtocol::SignatureHelpParams& args, const JSONValue& responseId); - - List<LanguageServerProtocol::CompletionItem> collectMembers( - WorkspaceVersion* wsVersion, Module* module, Expr* baseExpr); - -private: - SlangResult _executeSingle(); - slang::IGlobalSession* getOrCreateGlobalSession(); - void resetDiagnosticUpdateTime(); - void publishDiagnostics(); -}; - + static const char* _commitCharsArray[] = {",", ".", ";", ":", "(", ")", "[", "]", + "<", ">", "{", "}", "*", "&", "^", "%", + "!", "-", "=", "+", "|", "/", "?", " "}; + return makeArrayView(_commitCharsArray, SLANG_COUNT_OF(_commitCharsArray)); +} SlangResult LanguageServer::init(const InitializeParams& args) { @@ -105,14 +73,8 @@ String uriToCanonicalPath(const String& uri) return canonnicalPath; } -SlangResult LanguageServer::_executeSingle() +SlangResult LanguageServer::parseNextMessage() { - // If we don't have a message, we can quit for now - if (!m_connection->hasMessage()) - { - return SLANG_OK; - } - const JSONRPCMessageType msgType = m_connection->getMessageType(); switch (msgType) @@ -121,8 +83,6 @@ SlangResult LanguageServer::_executeSingle() { JSONRPCCall call; SLANG_RETURN_ON_FAIL(m_connection->getRPCOrSendError(&call)); - - // Do different things if (call.method == ExitParams::methodName) { m_quit = true; @@ -143,15 +103,14 @@ SlangResult LanguageServer::_executeSingle() InitializeResult result; result.serverInfo.name = "SlangLanguageServer"; result.serverInfo.version = "1.0"; - result.capabilities.positionEncoding = "utf-8"; + result.capabilities.positionEncoding = "utf-16"; result.capabilities.textDocumentSync.openClose = true; - result.capabilities.textDocumentSync.change = (int)TextDocumentSyncKind::Full; + result.capabilities.textDocumentSync.change = (int)TextDocumentSyncKind::Incremental; + result.capabilities.workspace.workspaceFolders.supported = true; + result.capabilities.workspace.workspaceFolders.changeNotifications = false; result.capabilities.hoverProvider = true; result.capabilities.definitionProvider = true; - const char* commitChars[] = {",", ".", ";", ":", "(", ")", "[", "]", - "<", ">", "{", "}", "*", "&", "^", "%", - "!", "-", "=", "+", "|", "/", "?"}; - for (auto ch : commitChars) + for (auto ch : getCommitChars()) result.capabilities.completionProvider.allCommitCharacters.add(ch); result.capabilities.completionProvider.triggerCharacters.add("."); result.capabilities.completionProvider.triggerCharacters.add(":"); @@ -166,84 +125,49 @@ SlangResult LanguageServer::_executeSingle() m_connection->sendResult(&result, call.id); return SLANG_OK; } - else if (call.method == DidOpenTextDocumentParams::methodName) - { - DidOpenTextDocumentParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return didOpenTextDocument(args); - } - else if (call.method == DidCloseTextDocumentParams::methodName) - { - DidCloseTextDocumentParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return didCloseTextDocument(args); - } - else if (call.method == DidChangeTextDocumentParams::methodName) - { - DidChangeTextDocumentParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return didChangeTextDocument(args); - } - else if (call.method == HoverParams::methodName) - { - HoverParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return hover(args, call.id); - } - else if (call.method == DefinitionParams::methodName) - { - DefinitionParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return gotoDefinition(args, call.id); - } - else if (call.method == CompletionParams::methodName) - { - CompletionParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return completion(args, call.id); - } - else if (call.method == SemanticTokensParams::methodName) - { - SemanticTokensParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return semanticTokens(args, call.id); - } - else if (call.method == SignatureHelpParams::methodName) - { - SignatureHelpParams args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return signatureHelp(args, call.id); - } - else if (call.method == "completionItem/resolve") - { - CompletionItem args; - SLANG_RETURN_ON_FAIL( - m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); - return completionResolve(args, call.id); - - } else if (call.method == "initialized") { + sendConfigRequest(); + registerCapability("workspace/didChangeConfiguration"); + m_initialized = true; return SLANG_OK; } - else if (call.method.startsWith("$/")) + else { - // Ignore. + queueJSONCall(call); return SLANG_OK; } - else + + } + case JSONRPCMessageType::Result: + { + JSONResultResponse response; + SLANG_RETURN_ON_FAIL(m_connection->getRPCOrSendError(&response)); + auto responseId = (int)m_connection->getContainer()->asInteger(response.id); + switch (responseId) { - return m_connection->sendError(JSONRPC::ErrorCode::MethodNotFound, call.id); + case kConfigResponseId: + if (response.result.getKind() == JSONValue::Kind::Array) + { + auto arr = m_connection->getContainer()->getArray(response.result); + if (arr.getCount() > 0) + { + updatePredefinedMacros(arr[0]); + } + } + break; } + return SLANG_OK; } + case JSONRPCMessageType::Error: + { +#if 0 // Enable for debug only + JSONRPCErrorResponse error; + SLANG_RETURN_ON_FAIL(m_connection->getRPCOrSendError(&error)); +#endif + return SLANG_OK; + } + break; default: { return m_connection->sendError( @@ -268,6 +192,14 @@ String getDeclSignatureString(DeclRef<Decl> declRef, ASTBuilder* astBuilder) ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); printer.addDeclSignature(declRef); + if (auto varDecl = as<VarDeclBase>(declRef.getDecl())) + { + auto& sb = printer.getStringBuilder(); + if (auto litExpr = as<LiteralExpr>(varDecl->initExpr)) + { + sb << " = " << litExpr->token.getContent(); + } + } return printer.getString(); } return "unknown"; @@ -332,6 +264,15 @@ SlangResult LanguageServer::hover( const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId) { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); + RefPtr<DocumentVersion> doc; + if (!m_workspace->openedDocuments.TryGetValue(canonicalPath, doc)) + { + m_connection->sendResult(NullResponse::get(), responseId); + return SLANG_OK; + } + Index line, col; + doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); + auto version = m_workspace->getCurrentVersion(); Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) @@ -340,12 +281,13 @@ SlangResult LanguageServer::hover( return SLANG_OK; } auto findResult = findASTNodesAt( + doc.Ptr(), version->linkage->getSourceManager(), parsedModule->getModuleDecl(), ASTLookupType::Decl, canonicalPath.getUnownedSlice(), - args.position.line + 1, - args.position.character + 1); + line, + col); if (findResult.getCount() == 0 || findResult[0].path.getCount() == 0) { m_connection->sendResult(NullResponse::get(), responseId); @@ -375,10 +317,18 @@ SlangResult LanguageServer::hover( hover.range.start.line = int(nodeHumaneLoc.line - 1); hover.range.end.line = int(nodeHumaneLoc.line - 1); hover.range.start.character = int(nodeHumaneLoc.column - 1); - if (declRef.getName()) + auto name = declRef.getName(); + if (auto ctorDecl = declRef.as<ConstructorDecl>()) + { + auto parent = ctorDecl.getDecl()->parentDecl; + if (parent) + { + name = parent->getName(); + } + } + if (name) { - hover.range.end.character = - int(nodeHumaneLoc.column + declRef.getName()->text.getLength() - 1); + hover.range.end.character = int(nodeHumaneLoc.column + name->text.getLength() - 1); } } }; @@ -413,6 +363,15 @@ SlangResult LanguageServer::gotoDefinition( const LanguageServerProtocol::DefinitionParams& args, const JSONValue& responseId) { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); + RefPtr<DocumentVersion> doc; + if (!m_workspace->openedDocuments.TryGetValue(canonicalPath, doc)) + { + m_connection->sendResult(NullResponse::get(), responseId); + return SLANG_OK; + } + Index line, col; + doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); + auto version = m_workspace->getCurrentVersion(); Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) @@ -421,12 +380,13 @@ SlangResult LanguageServer::gotoDefinition( return SLANG_OK; } auto findResult = findASTNodesAt( + doc.Ptr(), version->linkage->getSourceManager(), parsedModule->getModuleDecl(), ASTLookupType::Decl, canonicalPath.getUnownedSlice(), - args.position.line + 1, - args.position.character + 1); + line, + col); if (findResult.getCount() == 0 || findResult[0].path.getCount() == 0) { m_connection->sendResult(NullResponse::get(), responseId); @@ -444,7 +404,9 @@ SlangResult LanguageServer::gotoDefinition( if (declRefExpr->declRef.getDecl()) { auto location = version->linkage->getSourceManager()->getHumaneLoc( - declRefExpr->declRef.getNameLoc(), SourceLocType::Actual); + declRefExpr->declRef.getNameLoc().isValid() ? declRefExpr->declRef.getNameLoc() + : declRefExpr->declRef.getLoc(), + SourceLocType::Actual); auto name = declRefExpr->declRef.getName(); locations.add(LocationResult{location, name ? (int)name->text.getLength() : 0}); } @@ -485,28 +447,22 @@ SlangResult LanguageServer::gotoDefinition( for (auto loc : locations) { Location result; - result.uri = URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri; - result.range.start.line = int(loc.loc.line - 1); - result.range.start.character = int(loc.loc.column - 1); - result.range.end = result.range.start; - result.range.end.character += loc.length; - results.add(result); + if (File::exists(loc.loc.pathInfo.foundPath)) + { + result.uri = + URI::fromLocalFilePath(loc.loc.pathInfo.foundPath.getUnownedSlice()).uri; + result.range.start.line = int(loc.loc.line - 1); + result.range.start.character = int(loc.loc.column - 1); + result.range.end = result.range.start; + result.range.end.character += loc.length; + results.add(result); + } } m_connection->sendResult(&results, responseId); return SLANG_OK; } } -bool _isIdentifierChar(char ch) -{ - return ch >= '0' && ch <= '9' || ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch == '_'; -} - -bool _isWhitespaceChar(char ch) -{ - return ch == ' ' || ch == '\r' || ch == '\n' || ch == '\t'; -} - SlangResult LanguageServer::completion( const LanguageServerProtocol::CompletionParams& args, const JSONValue& responseId) { @@ -518,39 +474,16 @@ SlangResult LanguageServer::completion( m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } - - auto cursorOffset = doc->getOffset(args.position.line + 1, args.position.character + 1); + Index utf8Line, utf8Col; + doc->zeroBasedUTF16LocToOneBasedUTF8Loc( + args.position.line, args.position.character, utf8Line, utf8Col); + auto cursorOffset = doc->getOffset(utf8Line, utf8Col); if (cursorOffset == -1 || doc->getText().getLength() == 0) { m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } - // Scan backward until we locate a '.' or ':'. - if (cursorOffset == doc->getText().getLength()) - cursorOffset--; - while (cursorOffset > 0 && _isWhitespaceChar(doc->getText()[cursorOffset])) - { - cursorOffset--; - } - while (cursorOffset > 0 && _isIdentifierChar(doc->getText()[cursorOffset])) - { - cursorOffset--; - } - while (cursorOffset > 0 && _isWhitespaceChar(doc->getText()[cursorOffset])) - { - cursorOffset--; - } - if (cursorOffset > 0 && doc->getText()[cursorOffset] == ':') - cursorOffset--; - if (cursorOffset <= 0 || - (doc->getText()[cursorOffset] != '.' && doc->getText()[cursorOffset] != ':')) - { - m_connection->sendResult(NullResponse::get(), responseId); - return SLANG_OK; - } - Index line = 0; - Index col = 0; - doc->offsetToLineCol(cursorOffset, line, col); + auto version = m_workspace->getCurrentVersion(); Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) @@ -558,48 +491,26 @@ SlangResult LanguageServer::completion( m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } - auto findResult = findASTNodesAt( - version->linkage->getSourceManager(), - parsedModule->getModuleDecl(), - ASTLookupType::Decl, - canonicalPath.getUnownedSlice(), - line, - col); - if (findResult.getCount() != 1) + + CompletionContext context; + context.server = this; + context.cursorOffset = cursorOffset; + context.version = version; + context.doc = doc.Ptr(); + context.parsedModule = parsedModule; + context.responseId = responseId; + context.canonicalPath = canonicalPath.getUnownedSlice(); + context.line = utf8Line; + context.col = utf8Col; + if (SLANG_SUCCEEDED(context.tryCompleteMember())) { - m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } - if (findResult[0].path.getCount() == 0) + if (SLANG_SUCCEEDED(context.tryCompleteHLSLSemantic())) { - m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } - Expr* baseExpr = nullptr; - if (auto memberExpr = as<MemberExpr>(findResult[0].path.getLast())) - { - baseExpr = memberExpr->baseExpression; - } - else if (auto staticMemberExpr = as<StaticMemberExpr>(findResult[0].path.getLast())) - { - baseExpr = staticMemberExpr->baseExpression; - } - else if (auto swizzleExpr = as<SwizzleExpr>(findResult[0].path.getLast())) - { - baseExpr = swizzleExpr->base; - } - else if (auto matSwizzleExpr = as<MatrixSwizzleExpr>(findResult[0].path.getLast())) - { - baseExpr = matSwizzleExpr->base; - } - if (!baseExpr || !baseExpr->type.type || baseExpr->type.type->equals(version->linkage->getASTBuilder()->getErrorType())) - { - m_connection->sendResult(NullResponse::get(), responseId); - return SLANG_OK; - } - - List<LanguageServerProtocol::CompletionItem> items = collectMembers(version, parsedModule, baseExpr); - m_connection->sendResult(&items, responseId); + m_connection->sendResult(NullResponse::get(), responseId); return SLANG_OK; } @@ -628,6 +539,13 @@ SlangResult LanguageServer::semanticTokens( { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); + RefPtr<DocumentVersion> doc; + if (!m_workspace->openedDocuments.TryGetValue(canonicalPath, doc)) + { + m_connection->sendResult(NullResponse::get(), responseId); + return SLANG_OK; + } + auto version = m_workspace->getCurrentVersion(); Module* parsedModule = version->getOrLoadModule(canonicalPath); if (!parsedModule) @@ -636,7 +554,18 @@ SlangResult LanguageServer::semanticTokens( return SLANG_OK; } - auto tokens = getSemanticTokens(version->linkage, parsedModule, canonicalPath.getUnownedSlice()); + auto tokens = getSemanticTokens(version->linkage, parsedModule, canonicalPath.getUnownedSlice(), doc.Ptr()); + for (auto& token : tokens) + { + Index line, col; + doc->oneBasedUTF8LocToZeroBasedUTF16Loc(token.line, token.col, line, col); + Index lineEnd, colEnd; + doc->oneBasedUTF8LocToZeroBasedUTF16Loc( + token.line, token.col + token.length, lineEnd, colEnd); + token.line = (int)line; + token.col = (int)col; + token.length = (int)(colEnd - col); + } SemanticTokens response; response.resultId = ""; response.data = getEncodedTokens(tokens); @@ -648,6 +577,14 @@ SlangResult LanguageServer::signatureHelp( const LanguageServerProtocol::SignatureHelpParams& args, const JSONValue& responseId) { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); + RefPtr<DocumentVersion> doc; + if (!m_workspace->openedDocuments.TryGetValue(canonicalPath, doc)) + { + m_connection->sendResult(NullResponse::get(), responseId); + return SLANG_OK; + } + Index line, col; + doc->zeroBasedUTF16LocToOneBasedUTF8Loc(args.position.line, args.position.character, line, col); auto version = m_workspace->getCurrentVersion(); Module* parsedModule = version->getOrLoadModule(canonicalPath); @@ -658,12 +595,13 @@ SlangResult LanguageServer::signatureHelp( } auto findResult = findASTNodesAt( + doc.Ptr(), version->linkage->getSourceManager(), parsedModule->getModuleDecl(), ASTLookupType::Invoke, canonicalPath.getUnownedSlice(), - args.position.line + 1, - args.position.character + 1); + line, + col); if (findResult.getCount() == 0) { @@ -673,6 +611,7 @@ SlangResult LanguageServer::signatureHelp( AppExprBase* appExpr = nullptr; auto& declPath = findResult[0].path; + Loc currentLoc = {args.position.line + 1, args.position.character + 1}; for (Index i = declPath.getCount() - 1; i >= 0; i--) { if (auto expr = as<AppExprBase>(declPath[i])) @@ -681,8 +620,14 @@ SlangResult LanguageServer::signatureHelp( // This allows us to skip the invoke expr nodes for operators/implcit casts. if (expr->argumentDelimeterLocs.getCount()) { - appExpr = expr; - break; + auto start = Loc::fromSourceLoc(version->linkage->getSourceManager(), expr->argumentDelimeterLocs.getFirst()); + auto end = Loc::fromSourceLoc( + version->linkage->getSourceManager(), expr->argumentDelimeterLocs.getLast()); + if (start < currentLoc && currentLoc <= end) + { + appExpr = expr; + break; + } } } } @@ -746,7 +691,18 @@ SlangResult LanguageServer::signatureHelp( }; if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) { - addDeclRef(declRefExpr->declRef); + if (auto aggDecl = as<AggTypeDecl>(declRefExpr->declRef.getDecl())) + { + // Look for initializers + for (auto member : aggDecl->getMembersOfType<ConstructorDecl>()) + { + addDeclRef(DeclRef<Decl>(member, declRefExpr->declRef.substitutions)); + } + } + else + { + addDeclRef(declRefExpr->declRef); + } } else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr)) { @@ -773,171 +729,6 @@ SlangResult LanguageServer::signatureHelp( return SLANG_OK; } - -List<LanguageServerProtocol::CompletionItem> LanguageServer::collectMembers(WorkspaceVersion* version, Module* module, Expr* baseExpr) -{ - List<LanguageServerProtocol::CompletionItem> result; - auto linkage = version->linkage; - Type* type = baseExpr->type.type; - if (auto typeType = as<TypeType>(type)) - { - type = typeType->type; - } - version->currentCompletionItems.clear(); - if (type) - { - if (as<ArithmeticExpressionType>(type)) - { - // Hard code members for vector and matrix types. - result.clear(); - version->currentCompletionItems.clear(); - int elementCount = 0; - Type* elementType = nullptr; - const char* memberNames[4] = {"x", "y", "z", "w"}; - if (auto vectorType = as<VectorExpressionType>(type)) - { - if (auto elementCountVal = as<ConstantIntVal>(vectorType->elementCount)) - { - elementCount = (int)elementCountVal->value; - elementType = vectorType->elementType; - } - } - else if (auto matrixType = as<MatrixExpressionType>(type)) - { - if (auto elementCountVal = as<ConstantIntVal>(matrixType->getRowCount())) - { - elementCount = (int)elementCountVal->value; - elementType = matrixType->getRowType(); - } - } - String typeStr; - if (elementType) - typeStr = elementType->toString(); - for (int i = 0; i < elementCount; i++) - { - CompletionItem item; - item.data = 0; - item.detail = typeStr; - item.kind = LanguageServerProtocol::kCompletionItemKindVariable; - item.label = memberNames[i]; - result.add(item); - } - } - else - { - DiagnosticSink sink; - MemberCollectingContext context(linkage, module, &sink); - context.astBuilder = linkage->getASTBuilder(); - collectMembersInType(&context, type); - HashSet<String> deduplicateSet; - for (auto member : context.members) - { - CompletionItem item; - item.label = member->getName()->text; - item.kind = 0; - if (as<TypeConstraintDecl>(member)) - { - continue; - } - if (as<ConstructorDecl>(member)) - { - continue; - } - if (as<SubscriptDecl>(member)) - { - continue; - } - - if (item.label.startsWith("$")) - continue; - if (!deduplicateSet.Add(item.label)) - continue; - - if (as<StructDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindStruct; - } - else if (as<ClassDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindClass; - } - else if (as<InterfaceDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindInterface; - } - else if (as<SimpleTypeDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindClass; - } - else if (as<PropertyDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindProperty; - } - else if (as<EnumDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindEnum; - } - else if (as<VarDeclBase>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindVariable; - } - else if (as<EnumCaseDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindEnumMember; - } - else if (as<CallableDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindMethod; - } - else if (as<AssocTypeDecl>(member)) - { - item.kind = LanguageServerProtocol::kCompletionItemKindClass; - } - item.data = String(version->currentCompletionItems.getCount()); - result.add(item); - version->currentCompletionItems.add(member); - } - } - - for (auto& item : result) - { - switch (item.kind) - { - case LanguageServerProtocol::kCompletionItemKindMethod: - item.commitCharacters.add("("); - item.commitCharacters.add("["); - item.commitCharacters.add(" "); - break; - default: - item.commitCharacters.add("("); - item.commitCharacters.add(")"); - item.commitCharacters.add("."); - item.commitCharacters.add(";"); - item.commitCharacters.add(":"); - item.commitCharacters.add(","); - item.commitCharacters.add("<"); - item.commitCharacters.add(">"); - item.commitCharacters.add("["); - item.commitCharacters.add("]"); - item.commitCharacters.add("{"); - item.commitCharacters.add("}"); - item.commitCharacters.add("-"); - item.commitCharacters.add("*"); - item.commitCharacters.add("/"); - item.commitCharacters.add("%"); - item.commitCharacters.add("+"); - item.commitCharacters.add("="); - item.commitCharacters.add("&"); - item.commitCharacters.add("|"); - item.commitCharacters.add("!"); - item.commitCharacters.add(" "); - break; - } - } - } - return result; -} - void LanguageServer::publishDiagnostics() { time_t timeNow = 0; @@ -982,28 +773,236 @@ void LanguageServer::publishDiagnostics() } } +void LanguageServer::updatePredefinedMacros(const JSONValue& macros) +{ + if (macros.isValid()) + { + auto container = m_connection->getContainer(); + JSONToNativeConverter converter(container, m_connection->getSink()); + List<String> predefinedMacros; + if (SLANG_SUCCEEDED(converter.convert(macros, &predefinedMacros))) + { + if (m_workspace->updatePredefinedMacros(predefinedMacros)) + { + m_connection->sendCall( + UnownedStringSlice("workspace/semanticTokens/refresh"), JSONValue::makeInt(0)); + } + } + } +} + +void LanguageServer::sendConfigRequest() +{ + ConfigurationParams args; + ConfigurationItem item; + item.section = "slang.predefinedMacros"; + args.items.add(item); + m_connection->sendCall( + ConfigurationParams::methodName, &args, JSONValue::makeInt(kConfigResponseId)); +} + +void LanguageServer::registerCapability(const char* methodName) +{ + RegistrationParams args; + Registration reg; + reg.method = methodName; + reg.id = reg.method; + args.registrations.add(reg); + m_connection->sendCall( + UnownedStringSlice("client/registerCapability"), &args, JSONValue::makeInt(999)); +} + +void LanguageServer::logMessage(int type, String message) +{ + LanguageServerProtocol::LogMessageParams args; + args.type = type; + args.message = message; + m_connection->sendCall(LanguageServerProtocol::LogMessageParams::methodName, &args); +} + +SlangResult LanguageServer::queueJSONCall(JSONRPCCall call) +{ + Command cmd; + cmd.id = PersistentJSONValue(call.id, m_connection->getContainer()); + cmd.method = call.method; + if (call.method == DidOpenTextDocumentParams::methodName) + { + DidOpenTextDocumentParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.openDocArgs = args; + } + else if (call.method == DidCloseTextDocumentParams::methodName) + { + DidCloseTextDocumentParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.closeDocArgs = args; + } + else if (call.method == DidChangeTextDocumentParams::methodName) + { + DidChangeTextDocumentParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.changeDocArgs = args; + } + else if (call.method == HoverParams::methodName) + { + HoverParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.hoverArgs = args; + } + else if (call.method == DefinitionParams::methodName) + { + DefinitionParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.definitionArgs = args; + } + else if (call.method == CompletionParams::methodName) + { + CompletionParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.completionArgs = args; + } + else if (call.method == SemanticTokensParams::methodName) + { + SemanticTokensParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.semanticTokenArgs = args; + } + else if (call.method == SignatureHelpParams::methodName) + { + SignatureHelpParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.signatureHelpArgs = args; + } + else if (call.method == "completionItem/resolve") + { + CompletionItem args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.completionResolveArgs = args; + } + else if (call.method == DidChangeConfigurationParams::methodName) + { + DidChangeConfigurationParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + // We need to process it now instead of sending to queue. + // This is because there is reference to JSONValue that is only available here. + return didChangeConfiguration(args); + } + else if (call.method == "$/cancelRequest") + { + CancelParams args; + SLANG_RETURN_ON_FAIL(m_connection->toNativeArgsOrSendError(call.params, &args, call.id)); + cmd.cancelArgs = args; + } + commands.add(_Move(cmd)); + return SLANG_OK; +} + +SlangResult LanguageServer::runCommand(Command& call) +{ + // Do different things + if (call.method == DidOpenTextDocumentParams::methodName) + { + return didOpenTextDocument(call.openDocArgs.get()); + } + else if (call.method == DidCloseTextDocumentParams::methodName) + { + return didCloseTextDocument(call.closeDocArgs.get()); + } + else if (call.method == DidChangeTextDocumentParams::methodName) + { + return didChangeTextDocument(call.changeDocArgs.get()); + } + else if (call.method == HoverParams::methodName) + { + return hover(call.hoverArgs.get(), call.id); + } + else if (call.method == DefinitionParams::methodName) + { + return gotoDefinition(call.definitionArgs.get(), call.id); + } + else if (call.method == CompletionParams::methodName) + { + return completion(call.completionArgs.get(), call.id); + } + else if (call.method == SemanticTokensParams::methodName) + { + return semanticTokens(call.semanticTokenArgs.get(), call.id); + } + else if (call.method == SignatureHelpParams::methodName) + { + return signatureHelp(call.signatureHelpArgs.get(), call.id); + } + else if (call.method == "completionItem/resolve") + { + return completionResolve(call.completionResolveArgs.get(), call.id); + } + else if (call.method == DidChangeConfigurationParams::methodName) + { + return didChangeConfiguration(call.changeConfigArgs.get()); + } + else if (call.method.startsWith("$/")) + { + // Ignore. + return SLANG_OK; + } + else + { + return m_connection->sendError(JSONRPC::ErrorCode::MethodNotFound, call.id); + } +} + +void LanguageServer::processCommands() +{ + HashSet<int64_t> canceledIDs; + for (auto& cmd : commands) + { + if (cmd.method == "$/cancelRequest") + { + auto id = cmd.cancelArgs.get().id; + if (id > 0) + { + canceledIDs.Add(id); + } + } + } + const int kErrorRequestCanceled = -32800; + for (auto& cmd : commands) + { + if (cmd.id.getKind() == JSONValue::Kind::Integer && canceledIDs.Contains(cmd.id.asInteger())) + { + m_connection->sendError((JSONRPC::ErrorCode)kErrorRequestCanceled, cmd.id); + } + else + { + runCommand(cmd); + } + } +} + SlangResult LanguageServer::didCloseTextDocument(const DidCloseTextDocumentParams& args) { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); - m_workspace->openedDocuments.Remove(canonicalPath); - m_workspace->invalidate(); + m_workspace->closeDoc(canonicalPath); resetDiagnosticUpdateTime(); return SLANG_OK; } SlangResult LanguageServer::didChangeTextDocument(const DidChangeTextDocumentParams& args) { String canonicalPath = uriToCanonicalPath(args.textDocument.uri); - - RefPtr<DocumentVersion> doc; - if (m_workspace->openedDocuments.TryGetValue(canonicalPath, doc)) - { - doc->setText(args.contentChanges[0].text.getUnownedSlice()); - } - m_workspace->invalidate(); + for (auto change : args.contentChanges) + m_workspace->changeDoc(canonicalPath, change.range, change.text); resetDiagnosticUpdateTime(); return SLANG_OK; } +SlangResult LanguageServer::didChangeConfiguration( + const LanguageServerProtocol::DidChangeConfigurationParams& args) +{ + SLANG_UNUSED(args); + sendConfigRequest(); + return SLANG_OK; +} + void LanguageServer::update() { if (!m_workspace) @@ -1013,24 +1012,38 @@ void LanguageServer::update() SlangResult LanguageServer::execute() { - m_connection = new JSONRPCConnection(); m_connection->initWithStdStreams(); while (m_connection->isActive() && !m_quit) { // Consume all messages first. + commands.clear(); + auto start = platform::PerformanceCounter::now(); while (true) { m_connection->tryReadMessage(); if (!m_connection->hasMessage()) break; - const SlangResult res = _executeSingle(); - + parseNextMessage(); } - + auto parseTime = platform::PerformanceCounter::getElapsedTimeInSeconds(start); + auto parseEnd = platform::PerformanceCounter::now(); + processCommands(); // Now we can use this time to reparse user's code, report diagnostics, etc. update(); + auto workTime = platform::PerformanceCounter::getElapsedTimeInSeconds(parseEnd); + + if (commands.getCount() > 0 && m_initialized) + { + StringBuilder msgBuilder; + msgBuilder << "Server processed " << commands.getCount() << " commands, parsed in " + << String(int(parseTime * 1000)) << "ms, executed in " + << String(int(workTime * 1000)) << "ms"; + logMessage(3, msgBuilder.ProduceString()); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } return SLANG_OK; diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index e3abfb2e5..f9ba2ddfc 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -1,8 +1,124 @@ #pragma once #include "../../slang.h" +#include "../compiler-core/slang-json-rpc.h" +#include "../compiler-core/slang-json-rpc-connection.h" + +#include "slang-workspace-version.h" namespace Slang { +ArrayView<const char*> getCommitChars(); + +struct Command +{ + PersistentJSONValue id; + String method; + + template <typename T> struct Optional + { + public: + T* value = nullptr; + bool isValid() { return value != nullptr; } + T& operator=(const T& val) + { + delete value; + value = new T(val); + return *value; + } + T& operator=(Optional&& other) + { + if (other.isValid()) + *this = (other.get()); + other.value = nullptr; + return *value; + } + T& get() + { + SLANG_ASSERT(isValid()); + return *value; + } + Optional() = default; + Optional(const Optional& other) + { + if (other.isValid()) + *this = (other.get()); + } + Optional(Optional&& other) + { + if (other.isValid()) + *this = (other.get()); + other.value = nullptr; + } + + ~Optional() { delete value; } + }; + + Optional<LanguageServerProtocol::CompletionParams> completionArgs; + Optional<LanguageServerProtocol::CompletionItem> completionResolveArgs; + Optional<LanguageServerProtocol::DidChangeConfigurationParams> changeConfigArgs; + Optional<LanguageServerProtocol::SignatureHelpParams> signatureHelpArgs; + Optional<LanguageServerProtocol::DefinitionParams> definitionArgs; + Optional<LanguageServerProtocol::SemanticTokensParams> semanticTokenArgs; + Optional<LanguageServerProtocol::HoverParams> hoverArgs; + Optional<LanguageServerProtocol::DidOpenTextDocumentParams> openDocArgs; + Optional<LanguageServerProtocol::DidChangeTextDocumentParams> changeDocArgs; + Optional<LanguageServerProtocol::DidCloseTextDocumentParams> closeDocArgs; + Optional<LanguageServerProtocol::CancelParams> cancelArgs; +}; + +class LanguageServer +{ +private: + static const int kConfigResponseId = 0x1213; + +public: + bool m_initialized = false; + RefPtr<JSONRPCConnection> m_connection; + ComPtr<slang::IGlobalSession> m_session; + RefPtr<Workspace> m_workspace; + Dictionary<String, String> m_lastPublishedDiagnostics; + time_t m_lastDiagnosticUpdateTime = 0; + bool m_quit = false; + List<LanguageServerProtocol::WorkspaceFolder> m_workspaceFolders; + + SlangResult init(const LanguageServerProtocol::InitializeParams& args); + SlangResult execute(); + void update(); + SlangResult didOpenTextDocument(const LanguageServerProtocol::DidOpenTextDocumentParams& args); + SlangResult didCloseTextDocument( + const LanguageServerProtocol::DidCloseTextDocumentParams& args); + SlangResult didChangeTextDocument( + const LanguageServerProtocol::DidChangeTextDocumentParams& args); + SlangResult didChangeConfiguration( + const LanguageServerProtocol::DidChangeConfigurationParams& args); + SlangResult hover(const LanguageServerProtocol::HoverParams& args, const JSONValue& responseId); + SlangResult gotoDefinition( + const LanguageServerProtocol::DefinitionParams& args, const JSONValue& responseId); + SlangResult completion( + const LanguageServerProtocol::CompletionParams& args, const JSONValue& responseId); + SlangResult completionResolve( + const LanguageServerProtocol::CompletionItem& args, const JSONValue& responseId); + SlangResult semanticTokens( + const LanguageServerProtocol::SemanticTokensParams& args, const JSONValue& responseId); + SlangResult signatureHelp( + const LanguageServerProtocol::SignatureHelpParams& args, const JSONValue& responseId); + +private: + SlangResult parseNextMessage(); + slang::IGlobalSession* getOrCreateGlobalSession(); + void resetDiagnosticUpdateTime(); + void publishDiagnostics(); + void updatePredefinedMacros(const JSONValue& macros); + void sendConfigRequest(); + void registerCapability(const char* methodName); + void logMessage(int type, String message); + + List<Command> commands; + SlangResult queueJSONCall(JSONRPCCall call); + SlangResult runCommand(Command& cmd); + void processCommands(); +}; + SLANG_API SlangResult runLanguageServer(); } // namespace Slang diff --git a/source/slang/slang-mangled-lexer.cpp b/source/slang/slang-mangled-lexer.cpp index 32a8fa741..4920cfce3 100644 --- a/source/slang/slang-mangled-lexer.cpp +++ b/source/slang/slang-mangled-lexer.cpp @@ -158,8 +158,14 @@ UnownedStringSlice MangledLexer::readSimpleName() } } -UnownedStringSlice MangledLexer::readRawStringSegment() +String MangledLexer::readRawStringSegment() { + bool escapeMode = false; + if (peekChar() == 'R') + { + escapeMode = true; + nextChar(); + } // Read the length part UInt count = readCount(); if (count > UInt(m_end - m_cursor)) @@ -170,9 +176,52 @@ UnownedStringSlice MangledLexer::readRawStringSegment() auto result = UnownedStringSlice(m_cursor, m_cursor + count); m_cursor += count; + if (escapeMode) + return unescapeString(result); return result; } +String MangledLexer::unescapeString(UnownedStringSlice str) +{ + StringBuilder sb; + Index cursor = 0; + while (cursor < str.getLength()) + { + auto ch = str[cursor]; + auto nextCh = 0; + if (cursor + 1 < str.getLength()) + { + nextCh = str[cursor + 1]; + } + if (ch == '_' && nextCh == 'u') + { + sb.appendChar('_'); + cursor += 2; + } + else if (ch == '_') + { + cursor++; + Int charValue = 0; + while (cursor < str.getLength()) + { + auto current = str[cursor]; + if (current == 'x') + break; + charValue = charValue * 16 + CharUtil::getHexDigitValue(current); + cursor++; + } + sb.appendChar((char)charValue); + cursor++; + } + else + { + sb.appendChar(ch); + cursor++; + } + } + return sb.ProduceString(); +} + UInt MangledLexer::readParamCount() { _expect("p"); @@ -183,25 +232,24 @@ UInt MangledLexer::readParamCount() /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MangledNameParser !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ -/* static */SlangResult MangledNameParser::parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName) +/* static */SlangResult MangledNameParser::parseModuleName(const UnownedStringSlice& in, String& outModuleName) { MangledLexer lexer(in); - { switch (lexer.peekChar()) { - case 'T': - case 'G': - case 'V': - { - lexer.nextChar(); - break; - } - default: break; + case 'T': + case 'G': + case 'V': + { + lexer.nextChar(); + break; + } + default: break; } } - UnownedStringSlice name = lexer.readRawStringSegment(); + auto name = lexer.readRawStringSegment(); if (name.getLength() == 0) { return SLANG_FAIL; diff --git a/source/slang/slang-mangled-lexer.h b/source/slang/slang-mangled-lexer.h index 7d096e45e..c9d051259 100644 --- a/source/slang/slang-mangled-lexer.h +++ b/source/slang/slang-mangled-lexer.h @@ -24,7 +24,7 @@ public: SLANG_INLINE void readSimpleIntVal(); - UnownedStringSlice readRawStringSegment(); + String readRawStringSegment(); void readNamedType(); @@ -47,7 +47,8 @@ public: // Returns the current character and moves to next character. char nextChar() { return *m_cursor++; } - + static String unescapeString(UnownedStringSlice str); + /// Ctor SLANG_FORCE_INLINE MangledLexer(const UnownedStringSlice& slice); @@ -126,7 +127,7 @@ SLANG_INLINE void MangledLexer::_expect(char c) struct MangledNameParser { /// Tries to extract the module name from this mangled name. - static SlangResult parseModuleName(const UnownedStringSlice& in, UnownedStringSlice& outModuleName); + static SlangResult parseModuleName(const UnownedStringSlice& in, String& outModuleName); }; } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b2179c1af..38317009d 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3105,7 +3105,11 @@ namespace Slang static NodeBase* parseConstructorDecl(Parser* parser, void* /*userData*/) { ConstructorDecl* decl = parser->astBuilder->create<ConstructorDecl>(); - parser->FillPosition(decl); + + // Note: we leave the source location of this decl as invalid, to + // trigger the fallback logic that fills in the location of the + // `__init` keyword later. + parser->PushScope(decl); // TODO: we need to make sure that all initializers have @@ -3132,6 +3136,7 @@ namespace Slang AccessorDecl* decl = nullptr; auto loc = peekToken(parser).loc; + auto name = peekToken(parser).getName(); if( AdvanceIf(parser, "get") ) { decl = parser->astBuilder->create<GetterDecl>(); @@ -3150,6 +3155,8 @@ namespace Slang return nullptr; } decl->loc = loc; + decl->nameAndLoc.name = name; + decl->nameAndLoc.loc = loc; _addModifiers(decl, modifiers); diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index 344b4aa02..09aea6c8b 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -350,7 +350,7 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( { auto startChunk = chunk; - RefPtr<ASTBuilder> astBuilder; + RefPtr<ASTBuilder> astBuilder = options.astBuilder; NodeBase* astRootNode = nullptr; RefPtr<IRModule> irModule; @@ -385,8 +385,10 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( StringBuilder buf; buf << "tu" << out.modules.getCount(); - - astBuilder = new ASTBuilder(options.sharedASTBuilder, buf.ProduceString()); + if (!astBuilder) + { + astBuilder = new ASTBuilder(options.sharedASTBuilder, buf.ProduceString()); + } DefaultSerialObjectFactory objectFactory(astBuilder); @@ -419,7 +421,7 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( { UnownedStringSlice mangledName = reader.getStringSlice(SerialIndex(i)); - UnownedStringSlice moduleName; + String moduleName; SLANG_RETURN_ON_FAIL(MangledNameParser::parseModuleName(mangledName, moduleName)); // If we already have looked up this module and it has the same name just use what we have @@ -457,8 +459,11 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( options.sink->diagnose(SourceLoc::fromRaw(0), Diagnostics::unableToFindSymbolInModule, mangledName, moduleName); } - // If didn't find the export then we are done - return SLANG_FAIL; + // If didn't find the export then we create an UnresolvedDecl node to represent the error. + auto unresolved = astBuilder->create<UnresolvedDecl>(); + unresolved->nameAndLoc.name = + options.linkage->getNamePool()->getName(mangledName); + nodeBase = unresolved; } // set the result diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h index 5368f62e6..9ee0625bf 100644 --- a/source/slang/slang-serialize-container.h +++ b/source/slang/slang-serialize-container.h @@ -89,6 +89,7 @@ struct SerialContainerUtil SourceManager* sourceManager = nullptr; NamePool* namePool = nullptr; SharedASTBuilder* sharedASTBuilder = nullptr; + ASTBuilder* astBuilder = nullptr; // Optional. If not provided will create one in SerialContainerData. Linkage* linkage = nullptr; DiagnosticSink* sink = nullptr; }; diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp index a6b7e0ba8..351742e60 100644 --- a/source/slang/slang-serialize-factory.cpp +++ b/source/slang/slang-serialize-factory.cpp @@ -75,14 +75,12 @@ SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBas if (Decl* decl = as<Decl>(ptr)) { ModuleDecl* moduleDecl = findModuleForDecl(decl); - SLANG_ASSERT(moduleDecl); if (moduleDecl && moduleDecl != m_moduleDecl) { ASTBuilder* astBuilder = m_moduleDecl->module->getASTBuilder(); // It's a reference to a declaration in another module, so first get the symbol name. String mangledName = getMangledName(astBuilder, decl); - // Add as an import symbol return writer->addImportSymbol(mangledName); } @@ -110,6 +108,8 @@ SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBas // // For now we just ignore all stmts + // TODO(yong): We should by default serialize everything. The logic to skip bodies need to be + // behind a option flag. if (Stmt* stmt = as<Stmt>(ptr)) { // diff --git a/source/slang/slang-serialize-source-loc.cpp b/source/slang/slang-serialize-source-loc.cpp index 331ca96ae..35473a1be 100644 --- a/source/slang/slang-serialize-source-loc.cpp +++ b/source/slang/slang-serialize-source-loc.cpp @@ -286,13 +286,13 @@ SlangResult SerialSourceLocReader::read(const SerialSourceLocData* serialData, S const auto& lineInfo = lineInfos[lineInfoIndex]; const uint32_t offset = lineInfo.m_lineStartOffset; - SLANG_ASSERT(offset > 0); const int finishIndex = int(lineInfo.m_lineIndex); SLANG_ASSERT(finishIndex < numLines); for (; lineIndex < finishIndex; ++lineIndex) { + SLANG_ASSERT(offset > 0); lineBreakOffsets[lineIndex] = offset - 1; } lineBreakOffsets[lineIndex] = offset; diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp index 1e92c31ca..d93be86b7 100644 --- a/source/slang/slang-workspace-version.cpp +++ b/source/slang/slang-workspace-version.cpp @@ -2,6 +2,7 @@ #include "../core/slang-io.h" #include "../core/slang-file-system.h" #include "../compiler-core/slang-lexer.h" +#include "slang-serialize-container.h" namespace Slang { @@ -32,10 +33,75 @@ DocumentVersion* Workspace::openDoc(String path, String text) doc->setURI(URI::fromLocalFilePath(path.getUnownedSlice())); openedDocuments[path] = doc; searchPaths.Add(Path::getParentDirectory(path)); + moduleCache.invalidate(path); invalidate(); return doc.Ptr(); } +void Workspace::changeDoc(const String& path, LanguageServerProtocol::Range range, const String& text) +{ + RefPtr<DocumentVersion> doc; + if (openedDocuments.TryGetValue(path, doc)) + { + Index line, col; + doc->zeroBasedUTF16LocToOneBasedUTF8Loc(range.start.line, range.start.character, line, col); + auto startOffset = doc->getOffset(line, col); + doc->zeroBasedUTF16LocToOneBasedUTF8Loc(range.end.line, range.end.character, line, col); + auto endOffset = doc->getOffset(line, col); + auto originalText = doc->getText().getUnownedSlice(); + StringBuilder newText; + newText << originalText.head(startOffset) << text << originalText.tail(endOffset); + doc->setText(newText.ProduceString()); + } + moduleCache.invalidate(path); + invalidate(); +} + +void Workspace::closeDoc(const String& path) +{ + moduleCache.invalidate(path); + openedDocuments.Remove(path); + invalidate(); +} + +bool Workspace::updatePredefinedMacros(List<String> macros) +{ + List<OnwedPreprocessorMacroDefinition> newDefs; + for (auto macro : macros) + { + auto index = macro.indexOf('='); + OnwedPreprocessorMacroDefinition def; + def.name = macro.getUnownedSlice().head(index).trim(); + if (index != -1) + { + def.value = macro.getUnownedSlice().tail(index + 1).trim(); + } + newDefs.add(def); + } + + bool changed = false; + if (newDefs.getCount() != predefinedMacros.getCount()) + changed = true; + else + { + for (Index i = 0; i < newDefs.getCount(); i++) + { + if (newDefs[i].name != predefinedMacros[i].name || + newDefs[i].value != predefinedMacros[i].value) + { + changed = true; + break; + } + } + } + if (changed) + { + predefinedMacros = _Move(newDefs); + invalidate(); + } + return changed; +} + void Workspace::init(List<URI> rootDirURI, slang::IGlobalSession* globalSession) { for (auto uri : rootDirURI) @@ -73,7 +139,7 @@ void Workspace::init(List<URI> rootDirURI, slang::IGlobalSession* globalSession) void Workspace::invalidate() { currentVersion = nullptr; } -void parseDiagnostics(Dictionary<String, DocumentDiagnostics>& diagnostics, String compilerOutput) +void WorkspaceVersion::parseDiagnostics(String compilerOutput) { List<UnownedStringSlice> lines; StringUtil::calcLines(compilerOutput.getUnownedSlice(), lines); @@ -95,12 +161,12 @@ void parseDiagnostics(Dictionary<String, DocumentDiagnostics>& diagnostics, Stri int lineLoc = StringUtil::parseIntAndAdvancePos(line, pos); if (lineLoc == 0) lineLoc = 1; - diagnostic.range.end.line = diagnostic.range.start.line = lineLoc - 1; + diagnostic.range.end.line = diagnostic.range.start.line = lineLoc; pos++; int colLoc = StringUtil::parseIntAndAdvancePos(line, pos); if (colLoc == 0) colLoc = 1; - diagnostic.range.end.character = diagnostic.range.start.character = colLoc - 1; + diagnostic.range.end.character = diagnostic.range.start.character = colLoc; if (pos >= line.getLength()) continue; line = line.tail(colonIndex + 3); @@ -133,7 +199,31 @@ void parseDiagnostics(Dictionary<String, DocumentDiagnostics>& diagnostics, Stri auto tokenLength = StringUtil::parseIntAndAdvancePos(line, pos); diagnostic.range.end.character += tokenLength; } + + if (auto doc = workspace->openedDocuments.TryGetValue(fileName)) + { + // If the file is open, translate to UTF16 positions using the document. + Index lineUTF16, colUTF16; + doc->Ptr()->oneBasedUTF8LocToZeroBasedUTF16Loc( + diagnostic.range.start.line, diagnostic.range.start.character, lineUTF16, colUTF16); + diagnostic.range.start.line = (int)lineUTF16; + diagnostic.range.start.character = (int)colUTF16; + doc->Ptr()->oneBasedUTF8LocToZeroBasedUTF16Loc( + diagnostic.range.end.line, diagnostic.range.end.character, lineUTF16, colUTF16); + diagnostic.range.end.line = (int)lineUTF16; + diagnostic.range.end.character = (int)colUTF16; + } + else + { + // Otherwise, just return an 0-based position. + diagnostic.range.start.line--; + diagnostic.range.start.character--; + diagnostic.range.end.line--; + diagnostic.range.end.character--; + } diagnosticList.messages.Add(diagnostic); + if (diagnosticList.messages.Count() >= 1000) + break; } } @@ -148,17 +238,33 @@ RefPtr<WorkspaceVersion> Workspace::createWorkspaceVersion() slang::TargetDesc targetDesc = {}; targetDesc.profile = slangGlobalSession->findProfile("sm_6_6"); desc.targets = &targetDesc; - List<const char*> searchPathsRaw; - for (auto path : searchPaths) searchPathsRaw.add(path.getBuffer()); desc.searchPaths = searchPathsRaw.getBuffer(); desc.searchPathCount = searchPathsRaw.getCount(); + desc.preprocessorMacroCount = predefinedMacros.getCount(); + List<slang::PreprocessorMacroDesc> macroDescs; + for (auto& macro : predefinedMacros) + { + slang::PreprocessorMacroDesc macroDesc; + macroDesc.name = macro.name.getBuffer(); + macroDesc.value = macro.value.getBuffer(); + macroDescs.add(macroDesc); + } + desc.preprocessorMacros = macroDescs.getBuffer(); + ComPtr<slang::ISession> session; slangGlobalSession->createSession(desc, session.writeRef()); version->linkage = static_cast<Linkage*>(session.get()); + // TODO(yong): module cache does improves performance by 30%. However there are some issues + // that prevents the deserialization to resolve the imported decls from the correct module. + // This doesn't lead to crash, but may cause problems. We can enable this when the issues + // are fixed. +#if 0 + version->linkage->setModuleCache(&moduleCache); +#endif return version; } @@ -194,14 +300,79 @@ void* Workspace::getInterface(const Guid& uuid) void DocumentVersion::setText(const String& newText) { text = newText; - lineBreaks.clear(); - for (Index i = 0; i < newText.getLength(); i++) + StringUtil::calcLines(text.getUnownedSlice(), lines); + utf16CharStarts.clear(); +} +ArrayView<Index> DocumentVersion::getUTF16Boundaries(Index line) +{ + if (!utf16CharStarts.getCount()) + { + for (auto slice : lines) + { + List<Index> bounds; + Index index = 0; + while (index < slice.getLength()) + { + auto startIndex = index; + const Char32 codePoint = getUnicodePointFromUTF8( + [&]() -> Byte + { + if (index < slice.getLength()) + return slice[index++]; + else + return '\0'; + }); + if (!codePoint) + break; + Char16 buffer[2]; + int count = encodeUnicodePointToUTF16Reversed(codePoint, buffer); + for (int i = 0; i < count; i++) + bounds.add(startIndex); + } + bounds.add(slice.getLength()); + utf16CharStarts.add(_Move(bounds)); + } + } + return line >= 1 && line <= utf16CharStarts.getCount() ? utf16CharStarts[line - 1].getArrayView() + : ArrayView<Index>(); +} + +void DocumentVersion::oneBasedUTF8LocToZeroBasedUTF16Loc( + Index inLine, Index inCol, Index& outLine, Index& outCol) +{ + Index rsLine = inLine - 1; + auto line = lines[rsLine]; + auto bounds = getUTF16Boundaries(inLine); + outLine = rsLine; + outCol = std::lower_bound(bounds.begin(), bounds.end(), inCol - 1) - bounds.begin(); +} + +void DocumentVersion::zeroBasedUTF16LocToOneBasedUTF8Loc( + Index inLine, Index inCol, Index& outLine, Index& outCol) +{ + outLine = inLine + 1; + auto bounds = getUTF16Boundaries(inLine + 1); + outCol = inCol >=0 && inCol < bounds.getCount()? bounds[inCol] + 1 : 0; +} + +static bool _isIdentifierChar(char ch) +{ + return ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9' || ch == '_'; +} + +int DocumentVersion::getTokenLength(Index line, Index col) +{ + auto offset = getOffset(line, col); + if (offset >= 0) { - if (newText[i] == '\n') - lineBreaks.add(i); + Index pos = offset; + for (; pos < text.getLength() && _isIdentifierChar(text[pos]); ++pos) + {} + return (int)(pos - offset); } - lineBreaks.add(newText.getLength()); + return 0; } + ASTMarkup* WorkspaceVersion::getOrCreateMarkupAST(ModuleDecl* module) { RefPtr<ASTMarkup> astMarkup; @@ -209,6 +380,7 @@ ASTMarkup* WorkspaceVersion::getOrCreateMarkupAST(ModuleDecl* module) return astMarkup.Ptr(); DiagnosticSink sink; astMarkup = new ASTMarkup(); + sink.setSourceManager(linkage->getSourceManager()); ASTMarkupUtil::extract(module, linkage->getSourceManager(), &sink, astMarkup.Ptr()); markupASTs[module] = astMarkup; return astMarkup.Ptr(); @@ -238,7 +410,7 @@ Module* WorkspaceVersion::getOrLoadModule(String path) if (diagnosticBlob) { auto diagnosticString = String((const char*)diagnosticBlob->getBufferPointer()); - parseDiagnostics(diagnostics, diagnosticString); + parseDiagnostics(diagnosticString); auto docDiagnostic = diagnostics.TryGetValue(path); if (docDiagnostic) docDiagnostic->originalOutput = diagnosticString; @@ -246,4 +418,63 @@ Module* WorkspaceVersion::getOrLoadModule(String path) return static_cast<Module*>(parsedModule); } +RefPtr<Module> SerializedModuleCache::tryLoadModule( + Linkage* linkage, String filePath) +{ + Path::getCanonical(filePath, filePath); + if (List<uint8_t>* rawData = serializedModules.TryGetValue(filePath)) + { + RefPtr<MemoryStreamBase> memStream = + new MemoryStreamBase(FileAccess::Read, rawData->getBuffer(), rawData->getCount()); + RiffContainer riffContainer; + RiffUtil::read(memStream.Ptr(), riffContainer); + SerialContainerData outData; + SerialContainerUtil::ReadOptions options; + options.linkage = linkage; + options.namePool = linkage->getNamePool(); + options.session = linkage->getSessionImpl(); + options.sharedASTBuilder = linkage->getASTBuilder()->getSharedASTBuilder(); + options.astBuilder = linkage->getASTBuilder(); + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + options.sink = &sink; + options.sourceManager = linkage->getSourceManager(); + SLANG_RETURN_NULL_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, outData)); + if (outData.modules.getCount() == 1) + { + RefPtr<Module> module = new Module(linkage, linkage->getASTBuilder()); + auto moduleDecl = as<ModuleDecl>(outData.modules[0].astRootNode); + if (moduleDecl) + { + moduleDecl->module = module.Ptr(); + module->setModuleDecl(moduleDecl); + return module; + } + } + } + return nullptr; +} + +void SerializedModuleCache::storeModule( + Linkage* linkage, String filePath, RefPtr<Module> module) +{ + Path::getCanonical(filePath, filePath); + RiffContainer container; + SerialContainerUtil::WriteOptions options; + options.sourceManager = linkage->getSourceManager(); + options.compressionType = SerialCompressionType::None; + options.optionFlags = SerialOptionFlag::SourceLocation | SerialOptionFlag::ASTModule; + SerialContainerData data; + SerialContainerData::Module moduleData; + moduleData.astBuilder = linkage->getASTBuilder(); + moduleData.astRootNode = module->getModuleDecl(); + moduleData.irModule = nullptr; + data.modules.add(moduleData); + SerialContainerUtil::write(data, options, &container); + RefPtr<OwnedMemoryStream> memStream = new OwnedMemoryStream(FileAccess::Write); + RiffUtil::write(&container, memStream); + List<uint8_t> rawData; + memStream->swapContents(rawData); + serializedModules[filePath] = _Move(rawData); +} + } // namespace Slang diff --git a/source/slang/slang-workspace-version.h b/source/slang/slang-workspace-version.h index 2aa2619f1..ee4d51896 100644 --- a/source/slang/slang-workspace-version.h +++ b/source/slang/slang-workspace-version.h @@ -18,7 +18,8 @@ namespace Slang private: URI uri; String text; - List<Int> lineBreaks; + List<UnownedStringSlice> lines; + List<List<Index>> utf16CharStarts; public: void setURI(URI newURI) { @@ -27,43 +28,65 @@ namespace Slang URI getURI() { return uri; } const String& getText() { return text; } void setText(const String& newText); + + ArrayView<Index> getUTF16Boundaries(Index line); + + void oneBasedUTF8LocToZeroBasedUTF16Loc( + Index inLine, Index inCol, Index& outLine, Index& outCol); + void zeroBasedUTF16LocToOneBasedUTF8Loc( + Index inLine, Index inCol, Index& outLine, Index& outCol); + + // Get starting offset of line. + Index getLineStart(UnownedStringSlice line) { return line.begin() - text.begin(); } + + // Get offset from 1-based, utf-8 encoding location. Index getOffset(Index lineIndex, Index colIndex) { if(lineIndex < 0) return -1; - if (lineIndex - 1 >= lineBreaks.getCount()) + if (lineIndex - 1 >= lines.getCount()) return -1; - if (lineBreaks.getCount() == 0) + if (lines.getCount() == 0) return -1; - Index lineStart = lineIndex >= 2 ? lineBreaks[lineIndex - 2] : 0; + Index lineStart = lineIndex >= 1 ? getLineStart(lines[lineIndex - 1]) : 0; return lineStart + colIndex - 1; } + + // Get 1-based, utf-8 encoding location from offset. void offsetToLineCol(Index offset, Index& line, Index& col) { - auto firstGreater = std::upper_bound(lineBreaks.begin(), lineBreaks.end(), offset); - line = Index(firstGreater - lineBreaks.begin() + 1); - if (firstGreater == lineBreaks.begin()) + auto firstGreater = std::upper_bound( + lines.begin(), + lines.end(), + offset, + [this](Index first, UnownedStringSlice second) + { return first < getLineStart(second); }); + line = Index(firstGreater - lines.begin()); + if (firstGreater == lines.begin()) { col = offset + 1; } else { - col = Index(offset - *(firstGreater - 1)); + col = Index(offset - getLineStart(lines[line-1])) + 1; } } + + // Get line from 1-based index. UnownedStringSlice getLine(Index lineIndex) { if (lineIndex < 0) return UnownedStringSlice(); - if (lineIndex - 1 >= lineBreaks.getCount()) + if (lineIndex - 1 >= lines.getCount()) return UnownedStringSlice(); - if (lineBreaks.getCount() == 0) + if (lines.getCount() == 0) return UnownedStringSlice(); - Int lineStart = lineIndex >= 2 ? lineBreaks[lineIndex - 2] : 0; - Int lineEnd = lineBreaks[lineIndex - 1]; - return text.getUnownedSlice().subString(lineStart, lineEnd); + return lineIndex > 0 ? lines[lineIndex - 1] : UnownedStringSlice(); } + + // Get length of an identifier token starting at the specified position. + int getTokenLength(Index line, Index col); }; struct DocumentDiagnostics @@ -72,11 +95,25 @@ namespace Slang String originalOutput; }; + + class SerializedModuleCache + : public RefObject + , public IModuleCache + { + public: + Dictionary<String, List<uint8_t>> serializedModules; + + void invalidate(const String& path) { serializedModules.Remove(path); } + virtual RefPtr<Module> tryLoadModule(Linkage* linkage, String filePath) override; + virtual void storeModule(Linkage* linkage, String filePath, RefPtr<Module> module) override; + }; + class WorkspaceVersion : public RefObject { private: Dictionary<String, Module*> modules; Dictionary<ModuleDecl*, RefPtr<ASTMarkup>> markupASTs; + void parseDiagnostics(String compilerOutput); public: Workspace* workspace; RefPtr<Linkage> linkage; @@ -86,7 +123,12 @@ namespace Slang Module* getOrLoadModule(String path); }; - + + struct OnwedPreprocessorMacroDefinition + { + String name; + String value; + }; class Workspace : public ISlangFileSystem , public ComObject @@ -97,10 +139,18 @@ namespace Slang public: List<String> rootDirectories; OrderedHashSet<String> searchPaths; + List<OnwedPreprocessorMacroDefinition> predefinedMacros; + SerializedModuleCache moduleCache; slang::IGlobalSession* slangGlobalSession; Dictionary<String, RefPtr<DocumentVersion>> openedDocuments; DocumentVersion* openDoc(String path, String text); + void changeDoc(const String& path, LanguageServerProtocol::Range range, const String& text); + void closeDoc(const String& path); + + // Update predefined macro settings. Returns true if the new settings are different from existing ones. + bool updatePredefinedMacros(List<String> predefinedMacros); + void init(List<URI> rootDirURI, slang::IGlobalSession* globalSession); void invalidate(); WorkspaceVersion* getCurrentVersion(); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 7602096d4..02b1efedf 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -2611,6 +2611,11 @@ void Linkage::loadParsedModule( } } loadedModulesList.add(loadedModule); + + if (m_moduleCache) + { + m_moduleCache->storeModule(this, pathInfo.foundPath, loadedModule); + } } Module* Linkage::loadModule(String const& name) @@ -2794,6 +2799,19 @@ RefPtr<Module> Linkage::findOrImportModule( if (mapPathToLoadedModule.TryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule)) return loadedModule; + // Is this module in user provided cache? + // (yong): module cache is intended to speed up language server reparsing. + // currently it is *not* enabled in language server. + if (m_moduleCache) + { + loadedModule = m_moduleCache->tryLoadModule(this, filePathInfo.foundPath); + if (loadedModule) + { + mapPathToLoadedModule[filePathInfo.getMostUniqueIdentity()] = loadedModule; + return loadedModule; + } + } + // Try to load it ComPtr<ISlangBlob> fileContents; if(SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents))) @@ -2945,9 +2963,7 @@ static bool _canExportDeclSymbol(ASTNodeType type) { switch (type) { - case ASTNodeType::ModuleDecl: case ASTNodeType::EmptyDecl: - case ASTNodeType::NamespaceDecl: { return false; } |
