summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-language-server.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-26 22:40:06 -0700
committerGitHub <noreply@github.com>2023-04-26 22:40:06 -0700
commit7be108c379ccc7da3f46b30a2b5917104155d52b (patch)
treeb38b6d5483ba63b18d38c282c06dd55ff9c188ea /source/slang/slang-language-server.cpp
parent3acbe8145c60f4d1e7a180b4602a94269a489df5 (diff)
Intellisense: show info on decl kind and differentiability. (#2847)
Diffstat (limited to 'source/slang/slang-language-server.cpp')
-rw-r--r--source/slang/slang-language-server.cpp116
1 files changed, 116 insertions, 0 deletions
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index a2cafa55a..b826d7cfe 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -216,6 +216,55 @@ static bool isBoolType(Type* t)
return basicType->baseType == BaseType::Bool;
}
+String getDeclKindString(DeclRef<Decl> declRef)
+{
+ if (declRef.as<ParamDecl>())
+ {
+ return "(parameter) ";
+ }
+ else if (declRef.as<GenericTypeParamDecl>())
+ {
+ return "(generic type parameter) ";
+ }
+ else if (declRef.as<GenericValueParamDecl>())
+ {
+ return "(generic value parameter) ";
+ }
+ else if (declRef.as<AttributeDecl>())
+ {
+ return "(attribute) ";
+ }
+ else if (auto varDecl = declRef.as<VarDeclBase>())
+ {
+ auto parent = declRef.getParent();
+ if (as<GenericDecl>(parent))
+ parent = parent.getParent();
+ if (parent.as<InterfaceDecl>())
+ {
+ return "(associated constant) ";
+ }
+ else if (parent.as<AggTypeDeclBase>())
+ {
+ return "(field) ";
+ }
+ const char* scopeKind = "";
+ if (parent.as<NamespaceDeclBase>())
+ scopeKind = "global ";
+ else if (getParentDecl(declRef.getDecl()))
+ scopeKind = "local ";
+ StringBuilder sb;
+ sb << "(";
+ sb << scopeKind;
+ if (varDecl.as<LetDecl>())
+ sb << "value";
+ else
+ sb << "variable";
+ sb << ") ";
+ return sb.produceString();
+ }
+ return String();
+}
+
String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version)
{
if (declRef.getDecl())
@@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version)
astBuilder,
ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
ASTPrinter::OptionFlag::SimplifiedBuiltinType);
+ printer.getStringBuilder() << getDeclKindString(declRef);
printer.addDeclSignature(declRef);
if (auto varDecl = as<VarDeclBase>(declRef.getDecl()))
{
@@ -496,6 +546,72 @@ SlangResult LanguageServer::hover(
_tryGetDocumentation(sb, version, declRef.getDecl());
+ if (auto funcDecl = as<FunctionDeclBase>(declRef.getDecl()))
+ {
+ DiagnosticSink sink;
+ SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink);
+ SemanticsVisitor semanticsVisitor(&semanticContext);
+
+ auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl);
+ Decl* bwdDiff = nullptr;
+ Decl* fwdDiff = nullptr;
+ Decl* primalSubst = nullptr;
+ auto getDeclFromExpr = [&](Expr* expr) -> Decl*
+ {
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ return declRefExpr->declRef.getDecl();
+ return nullptr;
+ };
+ for (auto& assocDecl : assocDecls)
+ {
+ if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc)
+ fwdDiff = assocDecl->decl;
+ else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc)
+ bwdDiff = assocDecl->decl;
+ else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc)
+ primalSubst = assocDecl->decl;
+ }
+ bool isBackwardDifferentiable = false;
+ bool isForwardDifferentiable = false;
+ for (auto modifier : funcDecl->modifiers)
+ {
+ if (auto bwdDiffModifier = as<BackwardDerivativeAttribute>(modifier))
+ bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr);
+ else if (auto fwdDiffModifier = as<ForwardDerivativeAttribute>(modifier))
+ fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr);
+ else if (auto primalSubstModifier = as<PrimalSubstituteAttribute>(modifier))
+ primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr);
+ else if (as<ForwardDifferentiableAttribute>(modifier))
+ isForwardDifferentiable = true;
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ isBackwardDifferentiable = true;
+ }
+ if (primalSubst)
+ {
+ for (auto modifier : primalSubst->modifiers)
+ {
+ if (as<ForwardDifferentiableAttribute>(modifier))
+ isForwardDifferentiable = true;
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ isBackwardDifferentiable = true;
+ }
+ }
+ if (isBackwardDifferentiable)
+ {
+ sb << "\nForward and backward differentiable\n\n";
+ }
+ if (isForwardDifferentiable)
+ {
+ sb << "\nForward differentiable\n\n";
+ }
+ if (fwdDiff && fwdDiff->getName())
+ sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n";
+ if (bwdDiff && bwdDiff->getName())
+ sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n";
+ if (primalSubst && primalSubst->getName())
+ sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n";
+ }
+
auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(
declRef.getLoc(), SourceLocType::Actual);
appendDefinitionLocation(sb, m_workspace, humaneLoc);