From e13d38b6a281f444203410f09dab8b127e678975 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 16 Nov 2022 16:08:51 -0800 Subject: Language server improvements for auto-diff. (#2521) --- source/slang/slang-ast-iterator.h | 6 + source/slang/slang-check-expr.cpp | 26 ++- source/slang/slang-diagnostic-defs.h | 2 + source/slang/slang-language-server-ast-lookup.cpp | 23 +- source/slang/slang-language-server.cpp | 249 +++++++++++++++------- source/slang/slang-language-server.h | 4 +- source/slang/slang-workspace-version.cpp | 13 +- source/slang/slang-workspace-version.h | 2 + 8 files changed, 229 insertions(+), 96 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 257a910ea..ed396139e 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -262,6 +262,12 @@ struct ASTIterator { dispatchIfNotNull(expr->originalExpr); } + + void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) + { + iterator->maybeDispatchCallback(expr); + dispatchIfNotNull(expr->baseFunction); + } }; struct ASTIteratorStmtVisitor : public StmtVisitor diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 311a5944b..b43a03150 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2115,9 +2115,14 @@ namespace Slang resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as()) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) { - for (auto param : funcDecl.getDecl()->getParameters()) + funcDecl = as(genDecl->inner); + } + if (funcDecl) + { + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } @@ -2144,14 +2149,19 @@ namespace Slang resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(funcExpr))) { - if (auto funcDecl = declRefExpr->declRef.as()) + auto funcDecl = declRefExpr->declRef.as().getDecl(); + if (auto genDecl = as(declRefExpr->declRef.getDecl())) + { + funcDecl = as(genDecl->inner); + } + if (funcDecl) { - for (auto param : funcDecl.getDecl()->getParameters()) + for (auto param : funcDecl->getParameters()) { resultDiffExpr->newParameterNames.add(param->getName()); } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } - resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); } } }; @@ -2175,13 +2185,15 @@ namespace Slang { auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, nullptr, - expr->loc, + overloadedExpr->loc, nullptr); auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } else if (auto overloadedExpr2 = as(expr->baseFunction)) @@ -2191,9 +2203,11 @@ namespace Slang { auto candidateExpr = actions->createDifferentiateExpr(semantics); actions->fillDifferentiateExpr(candidateExpr, semantics, item); + candidateExpr->loc = expr->loc; result->candidiateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); + result->loc = expr->loc; return result; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0dfd06923..c69a1e9e6 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -494,6 +494,8 @@ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argu DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") +DIAGNOSTIC(30830, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") + DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") DIAGNOSTIC(39999, Fatal, complationCeased, "compilation ceased") diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index cd211b0f5..4b6c7f33d 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -242,8 +242,11 @@ public: } bool visitOverloadedExpr(OverloadedExpr* expr) { - if (dispatchIfNotNull(expr->base)) - return true; + { + PushNode pushNode(context, expr); + if (dispatchIfNotNull(expr->base)) + return true; + } if (expr->lookupResult2.getName() && _isLocInRange( context, @@ -263,6 +266,7 @@ public: if (dispatchIfNotNull(expr->base)) return true; bool result = false; + PushNode pushNode(context, expr); for (auto candidate : expr->candidiateExprs) { result |= dispatchIfNotNull(candidate); @@ -408,7 +412,20 @@ public: } bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); } bool visitTryExpr(TryExpr* expr) { return dispatchIfNotNull(expr->base); } - + bool visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) + { + auto humaneLoc = context->sourceManager->getHumaneLoc(expr->loc, SourceLocType::Actual); + auto tokenLen = context->doc->getTokenLength(humaneLoc.line, humaneLoc.column); + if (_isLocInRange(context, expr->loc, tokenLen)) + { + ASTLookupResult result; + result.path = context->nodePath; + result.path.add(expr); + context->results.add(result); + return true; + } + return dispatchIfNotNull(expr->baseFunction); + } }; struct ASTLookupStmtVisitor : public StmtVisitor diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 15217dcac..80d6e466b 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -10,7 +10,6 @@ #include #include "../core/slang-secure-crt.h" -#include "../core/slang-range.h" #include "../core/slang-char-util.h" #include "../core/slang-string-util.h" @@ -461,6 +460,29 @@ SlangResult LanguageServer::hover( Hover hover; auto leafNode = findResult[0].path.getLast(); + + auto maybeAppendAdditionalOverloadsHint = [&]() + { + Index numOverloads = 0; + for (Index i = findResult[0].path.getCount() - 1; i >= 0; i--) + { + auto node = findResult[0].path[i]; + if (auto overloadExpr = as(node)) + { + numOverloads = overloadExpr->lookupResult2.items.getCount(); + } + else if (auto overloadedExpr2 = as(node)) + { + numOverloads = overloadedExpr2->candidiateExprs.getCount(); + } + } + if (numOverloads > 1) + { + sb << "\n +" << numOverloads - 1 << " overload"; + if (numOverloads > 2) sb << "s"; + } + }; + auto fillDeclRefHoverInfo = [&](DeclRef declRef) { if (declRef.getDecl()) @@ -474,7 +496,7 @@ SlangResult LanguageServer::hover( auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( declRef.getLoc(), SourceLocType::Actual); appendDefinitionLocation(sb, m_workspace, humaneLoc); - + maybeAppendAdditionalOverloadsHint(); auto nodeHumaneLoc = version->linkage->getSourceManager()->getHumaneLoc(leafNode->loc); hover.range.start.line = int(nodeHumaneLoc.line - 1); @@ -495,6 +517,29 @@ SlangResult LanguageServer::hover( } } }; + auto fillExprHoverInfo = [&](Expr* expr) + { + if (auto declRefExpr = as(expr)) + return fillDeclRefHoverInfo(declRefExpr->declRef); + if (auto higherOrderExpr = as(expr)) + { + String documentation; + String signature = getExprDeclSignature(expr, &documentation, nullptr); + if (signature.getLength() == 0) + return; + sb << "```\n" + << signature + << "\n```\n"; + sb << documentation; + maybeAppendAdditionalOverloadsHint(); + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( + expr->loc, SourceLocType::Actual); + hover.range.start.line = int(humaneLoc.line - 1); + hover.range.end.line = int(humaneLoc.line - 1); + hover.range.start.character = int(humaneLoc.column - 1); + hover.range.end.character = hover.range.start.character + int(doc->getTokenLength(humaneLoc.line, humaneLoc.column)); + } + }; if (auto declRefExpr = as(leafNode)) { fillDeclRefHoverInfo(declRefExpr->declRef); @@ -504,6 +549,18 @@ SlangResult LanguageServer::hover( LookupResultItem& item = overloadedExpr->lookupResult2.item; fillDeclRefHoverInfo(item.declRef); } + else if (auto overloadedExpr2 = as(leafNode)) + { + if (overloadedExpr2->candidiateExprs.getCount() > 0) + { + auto candidateExpr = overloadedExpr2->candidiateExprs[0]; + fillExprHoverInfo(candidateExpr); + } + } + else if (auto higherOrderExpr = as(leafNode)) + { + fillExprHoverInfo(higherOrderExpr); + } else if (auto importDecl = as(leafNode)) { auto moduleLoc = getModuleLoc(version->linkage->getSourceManager(), importDecl->importedModuleDecl); @@ -871,6 +928,106 @@ SlangResult LanguageServer::semanticTokens( return SLANG_OK; } +String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation, List>* outParamRanges) +{ + if (auto declRefExpr = as(expr)) + { + return getDeclRefSignature(declRefExpr->declRef, outDocumentation, outParamRanges); + } + + auto higherOrderExpr = as(expr); + auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(higherOrderExpr)); + if (!declRefExpr) + return String(); + if (!declRefExpr->declRef.getDecl()) + return String(); + auto funcType = as(higherOrderExpr->type); + if (!funcType) + return String(); + + auto version = m_workspace->getCurrentVersion(); + + SignatureInformation sigInfo; + + ASTPrinter printer( + version->linkage->getASTBuilder(), + ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | + ASTPrinter::OptionFlag::SimplifiedBuiltinType); + + printer.addDeclKindPrefix(declRefExpr->declRef.getDecl()); + auto inner = higherOrderExpr; + int closingParentCount = 0; + while (inner) + { + printer.getStringBuilder() << getHigherOrderOperatorName(inner) << "("; + closingParentCount++; + inner = as(inner->baseFunction); + } + printer.addDeclPath(declRefExpr->declRef); + for (int i = 0; i < closingParentCount; i++) + printer.getStringBuilder() << ")"; + bool isFirst = true; + printer.getStringBuilder() << "("; + int paramIndex = 0; + for (auto param : funcType->paramTypes) + { + if (!isFirst) + printer.getStringBuilder() << ", "; + Slang::Range range; + range.begin = printer.getStringBuilder().getLength(); + if (paramIndex < higherOrderExpr->newParameterNames.getCount()) + { + if (higherOrderExpr->newParameterNames[paramIndex]) + { + printer.getStringBuilder() << higherOrderExpr->newParameterNames[paramIndex]->text << ": "; + } + } + printer.addType(param); + range.end = printer.getStringBuilder().getLength(); + if (outParamRanges) + outParamRanges->add(range); + isFirst = false; + paramIndex++; + } + printer.getStringBuilder() << ") -> "; + printer.addType(funcType->getResultType()); + + if (outDocumentation) + { + StringBuilder docSB; + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRefExpr->declRef.getLoc(), SourceLocType::Actual); + _tryGetDocumentation(docSB, version, declRefExpr->declRef.getDecl()); + appendDefinitionLocation(docSB, m_workspace, humaneLoc); + *outDocumentation = docSB.ProduceString(); + } + + return printer.getString(); +} + +String LanguageServer::getDeclRefSignature(DeclRef declRef, String* outDocumentation, List>* outParamRanges) +{ + auto version = m_workspace->getCurrentVersion(); + ASTPrinter printer( + version->linkage->getASTBuilder(), + ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | + ASTPrinter::OptionFlag::SimplifiedBuiltinType); + + printer.addDeclKindPrefix(declRef.getDecl()); + printer.addDeclPath(declRef); + printer.addDeclParams(declRef, outParamRanges); + printer.addDeclResultType(declRef); + + if (outDocumentation) + { + StringBuilder docSB; + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRef.getLoc(), SourceLocType::Actual); + _tryGetDocumentation(docSB, version, declRef.getDecl()); + appendDefinitionLocation(docSB, m_workspace, humaneLoc); + *outDocumentation = docSB.ProduceString(); + } + return printer.getString(); +} + SlangResult LanguageServer::signatureHelp( const LanguageServerProtocol::SignatureHelpParams& args, const JSONValue& responseId) { @@ -958,23 +1115,9 @@ SlangResult LanguageServer::signatureHelp( SignatureInformation sigInfo; List> paramRanges; - ASTPrinter printer( - version->linkage->getASTBuilder(), - ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | - ASTPrinter::OptionFlag::SimplifiedBuiltinType); - - printer.addDeclKindPrefix(declRef.getDecl()); - printer.addDeclPath(declRef); - printer.addDeclParams(declRef, ¶mRanges); - printer.addDeclResultType(declRef); - - sigInfo.label = printer.getString(); - - StringBuilder docSB; - auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRef.getLoc(), SourceLocType::Actual); - _tryGetDocumentation(docSB, version, declRef.getDecl()); - appendDefinitionLocation(docSB, m_workspace, humaneLoc); - sigInfo.documentation.value = docSB.ProduceString(); + String documentation; + sigInfo.label = getDeclRefSignature(declRef, &documentation, ¶mRanges); + sigInfo.documentation.value = documentation; sigInfo.documentation.kind = "markdown"; for (auto& range : paramRanges) @@ -989,72 +1132,14 @@ SlangResult LanguageServer::signatureHelp( auto addExpr = [&](Expr* expr) { - auto higherOrderExpr = as(expr); - if (!higherOrderExpr) - return; - auto funcType = as(higherOrderExpr->type); - if (!funcType) - return; - auto declRefExpr = as(getInnerMostExprFromHigherOrderExpr(higherOrderExpr)); - if (!declRefExpr) - return; - if (!declRefExpr->declRef.getDecl()) - return; - SignatureInformation sigInfo; - List> paramRanges; - ASTPrinter printer( - version->linkage->getASTBuilder(), - ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | - ASTPrinter::OptionFlag::SimplifiedBuiltinType); - - printer.addDeclKindPrefix(declRefExpr->declRef.getDecl()); - auto inner = higherOrderExpr; - int closingParentCount = 0; - while (inner) - { - printer.getStringBuilder() << getHigherOrderOperatorName(inner) << "("; - closingParentCount++; - inner = as(inner->baseFunction); - } - printer.addDeclPath(declRefExpr->declRef); - for (int i = 0; i < closingParentCount; i++) - printer.getStringBuilder() << ")"; - bool isFirst = true; - printer.getStringBuilder() << "("; - int paramIndex = 0; - for (auto param : funcType->paramTypes) - { - if (!isFirst) - printer.getStringBuilder() << ", "; - Slang::Range range; - range.begin = printer.getStringBuilder().getLength(); - if (paramIndex < higherOrderExpr->newParameterNames.getCount()) - { - if (higherOrderExpr->newParameterNames[paramIndex]) - { - printer.getStringBuilder() << higherOrderExpr->newParameterNames[paramIndex]->text << ": "; - } - } - printer.addType(param); - range.end = printer.getStringBuilder().getLength(); - paramRanges.add(range); - isFirst = false; - paramIndex++; - } - printer.getStringBuilder() << ") -> "; - printer.addType(funcType->getResultType()); - - sigInfo.label = printer.getString(); - - StringBuilder docSB; - auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRefExpr->declRef.getLoc(), SourceLocType::Actual); - _tryGetDocumentation(docSB, version, declRefExpr->declRef.getDecl()); - appendDefinitionLocation(docSB, m_workspace, humaneLoc); - sigInfo.documentation.value = docSB.ProduceString(); + String documentation; + sigInfo.label = getExprDeclSignature(expr, &documentation, ¶mRanges); + if (sigInfo.label.getLength() == 0) + return; + sigInfo.documentation.value = documentation; sigInfo.documentation.kind = "markdown"; - for (auto& range : paramRanges) { ParameterInformation paramInfo; diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h index 8ece0cc41..07c7fcfaa 100644 --- a/source/slang/slang-language-server.h +++ b/source/slang/slang-language-server.h @@ -1,6 +1,7 @@ #pragma once #include #include "../../slang.h" +#include "../core/slang-range.h" #include "../compiler-core/slang-json-rpc.h" #include "../compiler-core/slang-json-rpc-connection.h" #include "slang-workspace-version.h" @@ -131,7 +132,8 @@ public: const LanguageServerProtocol::DocumentRangeFormattingParams& args, const JSONValue& responseId); SlangResult onTypeFormatting( const LanguageServerProtocol::DocumentOnTypeFormattingParams& args, const JSONValue& responseId); - + String getExprDeclSignature(Expr* expr, String* outDocumentation, List>* outParamRanges); + String getDeclRefSignature(DeclRef declRef, String* outDocumentation, List>* outParamRanges); private: SlangResult parseNextMessage(); slang::IGlobalSession* getOrCreateGlobalSession(); diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp index 914393320..d5fc62e79 100644 --- a/source/slang/slang-workspace-version.cpp +++ b/source/slang/slang-workspace-version.cpp @@ -458,20 +458,25 @@ UnownedStringSlice DocumentVersion::peekIdentifier(Index& offset) return UnownedStringSlice(""); } - -int DocumentVersion::getTokenLength(Index line, Index col) +int DocumentVersion::getTokenLength(Index offset) { - auto offset = getOffset(line, col); if (offset >= 0) { Index pos = offset; for (; pos < text.getLength() && _isIdentifierChar(text[pos]); ++pos) - {} + { + } return (int)(pos - offset); } return 0; } +int DocumentVersion::getTokenLength(Index line, Index col) +{ + auto offset = getOffset(line, col); + return getTokenLength(offset); +} + ASTMarkup* WorkspaceVersion::getOrCreateMarkupAST(ModuleDecl* module) { RefPtr astMarkup; diff --git a/source/slang/slang-workspace-version.h b/source/slang/slang-workspace-version.h index 927bf46a8..bab9c963d 100644 --- a/source/slang/slang-workspace-version.h +++ b/source/slang/slang-workspace-version.h @@ -98,6 +98,8 @@ namespace Slang // Get length of an identifier token starting at the specified position. int getTokenLength(Index line, Index col); + int getTokenLength(Index offset); + }; struct DocumentDiagnostics -- cgit v1.2.3