summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-16 16:08:51 -0800
committerGitHub <noreply@github.com>2022-11-16 16:08:51 -0800
commite13d38b6a281f444203410f09dab8b127e678975 (patch)
treee8db1272ee8a729256515cc11a635c3c68752004 /source
parent801aa3b44254341018a1acbe754f2ce3b0900e2a (diff)
Language server improvements for auto-diff. (#2521)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-iterator.h6
-rw-r--r--source/slang/slang-check-expr.cpp26
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp23
-rw-r--r--source/slang/slang-language-server.cpp249
-rw-r--r--source/slang/slang-language-server.h4
-rw-r--r--source/slang/slang-workspace-version.cpp13
-rw-r--r--source/slang/slang-workspace-version.h2
8 files changed, 229 insertions, 96 deletions
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 257a910ea..ed396139e 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -262,6 +262,12 @@ struct ASTIterator
{
dispatchIfNotNull(expr->originalExpr);
}
+
+ void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr)
+ {
+ iterator->maybeDispatchCallback(expr);
+ dispatchIfNotNull(expr->baseFunction);
+ }
};
struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 311a5944b..b43a03150 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2115,9 +2115,14 @@ namespace Slang
resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
{
- if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
{
- for (auto param : funcDecl.getDecl()->getParameters())
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
+ {
+ for (auto param : funcDecl->getParameters())
{
resultDiffExpr->newParameterNames.add(param->getName());
}
@@ -2144,14 +2149,19 @@ namespace Slang
resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
{
- if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>())
+ auto funcDecl = declRefExpr->declRef.as<CallableDecl>().getDecl();
+ if (auto genDecl = as<GenericDecl>(declRefExpr->declRef.getDecl()))
+ {
+ funcDecl = as<CallableDecl>(genDecl->inner);
+ }
+ if (funcDecl)
{
- for (auto param : funcDecl.getDecl()->getParameters())
+ for (auto param : funcDecl->getParameters())
{
resultDiffExpr->newParameterNames.add(param->getName());
}
+ resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
}
- resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient"));
}
}
};
@@ -2175,13 +2185,15 @@ namespace Slang
{
auto lookupResultExpr = semantics->ConstructLookupResultExpr(item,
nullptr,
- expr->loc,
+ overloadedExpr->loc,
nullptr);
auto candidateExpr = actions->createDifferentiateExpr(semantics);
actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr);
+ candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
result->type.type = astBuilder->getOverloadedType();
+ result->loc = expr->loc;
return result;
}
else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction))
@@ -2191,9 +2203,11 @@ namespace Slang
{
auto candidateExpr = actions->createDifferentiateExpr(semantics);
actions->fillDifferentiateExpr(candidateExpr, semantics, item);
+ candidateExpr->loc = expr->loc;
result->candidiateExprs.add(candidateExpr);
}
result->type.type = astBuilder->getOverloadedType();
+ result->loc = expr->loc;
return result;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 0dfd06923..c69a1e9e6 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -494,6 +494,8 @@ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argu
DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)")
DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'")
+DIAGNOSTIC(30830, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
+
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error")
DIAGNOSTIC(39999, Fatal, complationCeased, "compilation ceased")
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index cd211b0f5..4b6c7f33d 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -242,8 +242,11 @@ public:
}
bool visitOverloadedExpr(OverloadedExpr* expr)
{
- if (dispatchIfNotNull(expr->base))
- return true;
+ {
+ PushNode pushNode(context, expr);
+ if (dispatchIfNotNull(expr->base))
+ return true;
+ }
if (expr->lookupResult2.getName() &&
_isLocInRange(
context,
@@ -263,6 +266,7 @@ public:
if (dispatchIfNotNull(expr->base))
return true;
bool result = false;
+ PushNode pushNode(context, expr);
for (auto candidate : expr->candidiateExprs)
{
result |= dispatchIfNotNull(candidate);
@@ -408,7 +412,20 @@ public:
}
bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); }
bool visitTryExpr(TryExpr* expr) { return dispatchIfNotNull(expr->base); }
-
+ bool visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr)
+ {
+ auto humaneLoc = context->sourceManager->getHumaneLoc(expr->loc, SourceLocType::Actual);
+ auto tokenLen = context->doc->getTokenLength(humaneLoc.line, humaneLoc.column);
+ if (_isLocInRange(context, expr->loc, tokenLen))
+ {
+ ASTLookupResult result;
+ result.path = context->nodePath;
+ result.path.add(expr);
+ context->results.add(result);
+ return true;
+ }
+ return dispatchIfNotNull(expr->baseFunction);
+ }
};
struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool>
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index 15217dcac..80d6e466b 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -10,7 +10,6 @@
#include <time.h>
#include "../core/slang-secure-crt.h"
-#include "../core/slang-range.h"
#include "../core/slang-char-util.h"
#include "../core/slang-string-util.h"
@@ -461,6 +460,29 @@ SlangResult LanguageServer::hover(
Hover hover;
auto leafNode = findResult[0].path.getLast();
+
+ auto maybeAppendAdditionalOverloadsHint = [&]()
+ {
+ Index numOverloads = 0;
+ for (Index i = findResult[0].path.getCount() - 1; i >= 0; i--)
+ {
+ auto node = findResult[0].path[i];
+ if (auto overloadExpr = as<OverloadedExpr>(node))
+ {
+ numOverloads = overloadExpr->lookupResult2.items.getCount();
+ }
+ else if (auto overloadedExpr2 = as<OverloadedExpr2>(node))
+ {
+ numOverloads = overloadedExpr2->candidiateExprs.getCount();
+ }
+ }
+ if (numOverloads > 1)
+ {
+ sb << "\n +" << numOverloads - 1 << " overload";
+ if (numOverloads > 2) sb << "s";
+ }
+ };
+
auto fillDeclRefHoverInfo = [&](DeclRef<Decl> declRef)
{
if (declRef.getDecl())
@@ -474,7 +496,7 @@ SlangResult LanguageServer::hover(
auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(
declRef.getLoc(), SourceLocType::Actual);
appendDefinitionLocation(sb, m_workspace, humaneLoc);
-
+ maybeAppendAdditionalOverloadsHint();
auto nodeHumaneLoc =
version->linkage->getSourceManager()->getHumaneLoc(leafNode->loc);
hover.range.start.line = int(nodeHumaneLoc.line - 1);
@@ -495,6 +517,29 @@ SlangResult LanguageServer::hover(
}
}
};
+ auto fillExprHoverInfo = [&](Expr* expr)
+ {
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ return fillDeclRefHoverInfo(declRefExpr->declRef);
+ if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr))
+ {
+ String documentation;
+ String signature = getExprDeclSignature(expr, &documentation, nullptr);
+ if (signature.getLength() == 0)
+ return;
+ sb << "```\n"
+ << signature
+ << "\n```\n";
+ sb << documentation;
+ maybeAppendAdditionalOverloadsHint();
+ auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(
+ expr->loc, SourceLocType::Actual);
+ hover.range.start.line = int(humaneLoc.line - 1);
+ hover.range.end.line = int(humaneLoc.line - 1);
+ hover.range.start.character = int(humaneLoc.column - 1);
+ hover.range.end.character = hover.range.start.character + int(doc->getTokenLength(humaneLoc.line, humaneLoc.column));
+ }
+ };
if (auto declRefExpr = as<DeclRefExpr>(leafNode))
{
fillDeclRefHoverInfo(declRefExpr->declRef);
@@ -504,6 +549,18 @@ SlangResult LanguageServer::hover(
LookupResultItem& item = overloadedExpr->lookupResult2.item;
fillDeclRefHoverInfo(item.declRef);
}
+ else if (auto overloadedExpr2 = as<OverloadedExpr2>(leafNode))
+ {
+ if (overloadedExpr2->candidiateExprs.getCount() > 0)
+ {
+ auto candidateExpr = overloadedExpr2->candidiateExprs[0];
+ fillExprHoverInfo(candidateExpr);
+ }
+ }
+ else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(leafNode))
+ {
+ fillExprHoverInfo(higherOrderExpr);
+ }
else if (auto importDecl = as<ImportDecl>(leafNode))
{
auto moduleLoc = getModuleLoc(version->linkage->getSourceManager(), importDecl->importedModuleDecl);
@@ -871,6 +928,106 @@ SlangResult LanguageServer::semanticTokens(
return SLANG_OK;
}
+String LanguageServer::getExprDeclSignature(Expr* expr, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges)
+{
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ {
+ return getDeclRefSignature(declRefExpr->declRef, outDocumentation, outParamRanges);
+ }
+
+ auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr);
+ auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderExpr));
+ if (!declRefExpr)
+ return String();
+ if (!declRefExpr->declRef.getDecl())
+ return String();
+ auto funcType = as<FuncType>(higherOrderExpr->type);
+ if (!funcType)
+ return String();
+
+ auto version = m_workspace->getCurrentVersion();
+
+ SignatureInformation sigInfo;
+
+ ASTPrinter printer(
+ version->linkage->getASTBuilder(),
+ ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
+ ASTPrinter::OptionFlag::SimplifiedBuiltinType);
+
+ printer.addDeclKindPrefix(declRefExpr->declRef.getDecl());
+ auto inner = higherOrderExpr;
+ int closingParentCount = 0;
+ while (inner)
+ {
+ printer.getStringBuilder() << getHigherOrderOperatorName(inner) << "(";
+ closingParentCount++;
+ inner = as<HigherOrderInvokeExpr>(inner->baseFunction);
+ }
+ printer.addDeclPath(declRefExpr->declRef);
+ for (int i = 0; i < closingParentCount; i++)
+ printer.getStringBuilder() << ")";
+ bool isFirst = true;
+ printer.getStringBuilder() << "(";
+ int paramIndex = 0;
+ for (auto param : funcType->paramTypes)
+ {
+ if (!isFirst)
+ printer.getStringBuilder() << ", ";
+ Slang::Range<Index> range;
+ range.begin = printer.getStringBuilder().getLength();
+ if (paramIndex < higherOrderExpr->newParameterNames.getCount())
+ {
+ if (higherOrderExpr->newParameterNames[paramIndex])
+ {
+ printer.getStringBuilder() << higherOrderExpr->newParameterNames[paramIndex]->text << ": ";
+ }
+ }
+ printer.addType(param);
+ range.end = printer.getStringBuilder().getLength();
+ if (outParamRanges)
+ outParamRanges->add(range);
+ isFirst = false;
+ paramIndex++;
+ }
+ printer.getStringBuilder() << ") -> ";
+ printer.addType(funcType->getResultType());
+
+ if (outDocumentation)
+ {
+ StringBuilder docSB;
+ auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRefExpr->declRef.getLoc(), SourceLocType::Actual);
+ _tryGetDocumentation(docSB, version, declRefExpr->declRef.getDecl());
+ appendDefinitionLocation(docSB, m_workspace, humaneLoc);
+ *outDocumentation = docSB.ProduceString();
+ }
+
+ return printer.getString();
+}
+
+String LanguageServer::getDeclRefSignature(DeclRef<Decl> declRef, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges)
+{
+ auto version = m_workspace->getCurrentVersion();
+ ASTPrinter printer(
+ version->linkage->getASTBuilder(),
+ ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
+ ASTPrinter::OptionFlag::SimplifiedBuiltinType);
+
+ printer.addDeclKindPrefix(declRef.getDecl());
+ printer.addDeclPath(declRef);
+ printer.addDeclParams(declRef, outParamRanges);
+ printer.addDeclResultType(declRef);
+
+ if (outDocumentation)
+ {
+ StringBuilder docSB;
+ auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRef.getLoc(), SourceLocType::Actual);
+ _tryGetDocumentation(docSB, version, declRef.getDecl());
+ appendDefinitionLocation(docSB, m_workspace, humaneLoc);
+ *outDocumentation = docSB.ProduceString();
+ }
+ return printer.getString();
+}
+
SlangResult LanguageServer::signatureHelp(
const LanguageServerProtocol::SignatureHelpParams& args, const JSONValue& responseId)
{
@@ -958,23 +1115,9 @@ SlangResult LanguageServer::signatureHelp(
SignatureInformation sigInfo;
List<Slang::Range<Index>> paramRanges;
- ASTPrinter printer(
- version->linkage->getASTBuilder(),
- ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
- ASTPrinter::OptionFlag::SimplifiedBuiltinType);
-
- printer.addDeclKindPrefix(declRef.getDecl());
- printer.addDeclPath(declRef);
- printer.addDeclParams(declRef, &paramRanges);
- printer.addDeclResultType(declRef);
-
- sigInfo.label = printer.getString();
-
- StringBuilder docSB;
- auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRef.getLoc(), SourceLocType::Actual);
- _tryGetDocumentation(docSB, version, declRef.getDecl());
- appendDefinitionLocation(docSB, m_workspace, humaneLoc);
- sigInfo.documentation.value = docSB.ProduceString();
+ String documentation;
+ sigInfo.label = getDeclRefSignature(declRef, &documentation, &paramRanges);
+ sigInfo.documentation.value = documentation;
sigInfo.documentation.kind = "markdown";
for (auto& range : paramRanges)
@@ -989,72 +1132,14 @@ SlangResult LanguageServer::signatureHelp(
auto addExpr = [&](Expr* expr)
{
- auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr);
- if (!higherOrderExpr)
- return;
- auto funcType = as<FuncType>(higherOrderExpr->type);
- if (!funcType)
- return;
- auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderExpr));
- if (!declRefExpr)
- return;
- if (!declRefExpr->declRef.getDecl())
- return;
-
SignatureInformation sigInfo;
-
List<Slang::Range<Index>> paramRanges;
- ASTPrinter printer(
- version->linkage->getASTBuilder(),
- ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
- ASTPrinter::OptionFlag::SimplifiedBuiltinType);
-
- printer.addDeclKindPrefix(declRefExpr->declRef.getDecl());
- auto inner = higherOrderExpr;
- int closingParentCount = 0;
- while (inner)
- {
- printer.getStringBuilder() << getHigherOrderOperatorName(inner) << "(";
- closingParentCount++;
- inner = as<HigherOrderInvokeExpr>(inner->baseFunction);
- }
- printer.addDeclPath(declRefExpr->declRef);
- for (int i = 0; i < closingParentCount; i++)
- printer.getStringBuilder() << ")";
- bool isFirst = true;
- printer.getStringBuilder() << "(";
- int paramIndex = 0;
- for (auto param : funcType->paramTypes)
- {
- if (!isFirst)
- printer.getStringBuilder() << ", ";
- Slang::Range<Index> range;
- range.begin = printer.getStringBuilder().getLength();
- if (paramIndex < higherOrderExpr->newParameterNames.getCount())
- {
- if (higherOrderExpr->newParameterNames[paramIndex])
- {
- printer.getStringBuilder() << higherOrderExpr->newParameterNames[paramIndex]->text << ": ";
- }
- }
- printer.addType(param);
- range.end = printer.getStringBuilder().getLength();
- paramRanges.add(range);
- isFirst = false;
- paramIndex++;
- }
- printer.getStringBuilder() << ") -> ";
- printer.addType(funcType->getResultType());
-
- sigInfo.label = printer.getString();
-
- StringBuilder docSB;
- auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRefExpr->declRef.getLoc(), SourceLocType::Actual);
- _tryGetDocumentation(docSB, version, declRefExpr->declRef.getDecl());
- appendDefinitionLocation(docSB, m_workspace, humaneLoc);
- sigInfo.documentation.value = docSB.ProduceString();
+ String documentation;
+ sigInfo.label = getExprDeclSignature(expr, &documentation, &paramRanges);
+ if (sigInfo.label.getLength() == 0)
+ return;
+ sigInfo.documentation.value = documentation;
sigInfo.documentation.kind = "markdown";
-
for (auto& range : paramRanges)
{
ParameterInformation paramInfo;
diff --git a/source/slang/slang-language-server.h b/source/slang/slang-language-server.h
index 8ece0cc41..07c7fcfaa 100644
--- a/source/slang/slang-language-server.h
+++ b/source/slang/slang-language-server.h
@@ -1,6 +1,7 @@
#pragma once
#include <chrono>
#include "../../slang.h"
+#include "../core/slang-range.h"
#include "../compiler-core/slang-json-rpc.h"
#include "../compiler-core/slang-json-rpc-connection.h"
#include "slang-workspace-version.h"
@@ -131,7 +132,8 @@ public:
const LanguageServerProtocol::DocumentRangeFormattingParams& args, const JSONValue& responseId);
SlangResult onTypeFormatting(
const LanguageServerProtocol::DocumentOnTypeFormattingParams& args, const JSONValue& responseId);
-
+ String getExprDeclSignature(Expr* expr, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges);
+ String getDeclRefSignature(DeclRef<Decl> declRef, String* outDocumentation, List<Slang::Range<Index>>* outParamRanges);
private:
SlangResult parseNextMessage();
slang::IGlobalSession* getOrCreateGlobalSession();
diff --git a/source/slang/slang-workspace-version.cpp b/source/slang/slang-workspace-version.cpp
index 914393320..d5fc62e79 100644
--- a/source/slang/slang-workspace-version.cpp
+++ b/source/slang/slang-workspace-version.cpp
@@ -458,20 +458,25 @@ UnownedStringSlice DocumentVersion::peekIdentifier(Index& offset)
return UnownedStringSlice("");
}
-
-int DocumentVersion::getTokenLength(Index line, Index col)
+int DocumentVersion::getTokenLength(Index offset)
{
- auto offset = getOffset(line, col);
if (offset >= 0)
{
Index pos = offset;
for (; pos < text.getLength() && _isIdentifierChar(text[pos]); ++pos)
- {}
+ {
+ }
return (int)(pos - offset);
}
return 0;
}
+int DocumentVersion::getTokenLength(Index line, Index col)
+{
+ auto offset = getOffset(line, col);
+ return getTokenLength(offset);
+}
+
ASTMarkup* WorkspaceVersion::getOrCreateMarkupAST(ModuleDecl* module)
{
RefPtr<ASTMarkup> astMarkup;
diff --git a/source/slang/slang-workspace-version.h b/source/slang/slang-workspace-version.h
index 927bf46a8..bab9c963d 100644
--- a/source/slang/slang-workspace-version.h
+++ b/source/slang/slang-workspace-version.h
@@ -98,6 +98,8 @@ namespace Slang
// Get length of an identifier token starting at the specified position.
int getTokenLength(Index line, Index col);
+ int getTokenLength(Index offset);
+
};
struct DocumentDiagnostics