summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-26 22:40:06 -0700
committerGitHub <noreply@github.com>2023-04-26 22:40:06 -0700
commit7be108c379ccc7da3f46b30a2b5917104155d52b (patch)
treeb38b6d5483ba63b18d38c282c06dd55ff9c188ea
parent3acbe8145c60f4d1e7a180b4602a94269a489df5 (diff)
Intellisense: show info on decl kind and differentiability. (#2847)
-rw-r--r--source/slang/slang-ast-print.cpp4
-rw-r--r--source/slang/slang-language-server.cpp116
-rw-r--r--source/slang/slang-syntax.cpp11
-rw-r--r--source/slang/slang-syntax.h2
-rw-r--r--tests/autodiff/primal-substitute-3.slang1
-rw-r--r--tests/language-server/ordinary-comment-hover-info.slang.expected.txt2
-rw-r--r--tests/language-server/robustness-1.slang.expected.txt2
-rw-r--r--tests/language-server/robustness-3.slang.expected.txt2
-rw-r--r--tests/language-server/smoke.slang.expected.txt2
9 files changed, 137 insertions, 5 deletions
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index b38b58385..23430cd13 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -168,7 +168,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth)
// If the parent declaration is a generic, then we need to print out its
// signature
- if (parentGenericDeclRef)
+ if (parentGenericDeclRef &&
+ !declRef.as<GenericValueParamDecl>() &&
+ !declRef.as<GenericTypeParamDecl>())
{
auto genSubst = as<GenericSubstitution>(declRef.substitutions.substitutions);
if (genSubst)
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index a2cafa55a..b826d7cfe 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -216,6 +216,55 @@ static bool isBoolType(Type* t)
return basicType->baseType == BaseType::Bool;
}
+String getDeclKindString(DeclRef<Decl> declRef)
+{
+ if (declRef.as<ParamDecl>())
+ {
+ return "(parameter) ";
+ }
+ else if (declRef.as<GenericTypeParamDecl>())
+ {
+ return "(generic type parameter) ";
+ }
+ else if (declRef.as<GenericValueParamDecl>())
+ {
+ return "(generic value parameter) ";
+ }
+ else if (declRef.as<AttributeDecl>())
+ {
+ return "(attribute) ";
+ }
+ else if (auto varDecl = declRef.as<VarDeclBase>())
+ {
+ auto parent = declRef.getParent();
+ if (as<GenericDecl>(parent))
+ parent = parent.getParent();
+ if (parent.as<InterfaceDecl>())
+ {
+ return "(associated constant) ";
+ }
+ else if (parent.as<AggTypeDeclBase>())
+ {
+ return "(field) ";
+ }
+ const char* scopeKind = "";
+ if (parent.as<NamespaceDeclBase>())
+ scopeKind = "global ";
+ else if (getParentDecl(declRef.getDecl()))
+ scopeKind = "local ";
+ StringBuilder sb;
+ sb << "(";
+ sb << scopeKind;
+ if (varDecl.as<LetDecl>())
+ sb << "value";
+ else
+ sb << "variable";
+ sb << ") ";
+ return sb.produceString();
+ }
+ return String();
+}
+
String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version)
{
if (declRef.getDecl())
@@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version)
astBuilder,
ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords |
ASTPrinter::OptionFlag::SimplifiedBuiltinType);
+ printer.getStringBuilder() << getDeclKindString(declRef);
printer.addDeclSignature(declRef);
if (auto varDecl = as<VarDeclBase>(declRef.getDecl()))
{
@@ -496,6 +546,72 @@ SlangResult LanguageServer::hover(
_tryGetDocumentation(sb, version, declRef.getDecl());
+ if (auto funcDecl = as<FunctionDeclBase>(declRef.getDecl()))
+ {
+ DiagnosticSink sink;
+ SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink);
+ SemanticsVisitor semanticsVisitor(&semanticContext);
+
+ auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl);
+ Decl* bwdDiff = nullptr;
+ Decl* fwdDiff = nullptr;
+ Decl* primalSubst = nullptr;
+ auto getDeclFromExpr = [&](Expr* expr) -> Decl*
+ {
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ return declRefExpr->declRef.getDecl();
+ return nullptr;
+ };
+ for (auto& assocDecl : assocDecls)
+ {
+ if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc)
+ fwdDiff = assocDecl->decl;
+ else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc)
+ bwdDiff = assocDecl->decl;
+ else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc)
+ primalSubst = assocDecl->decl;
+ }
+ bool isBackwardDifferentiable = false;
+ bool isForwardDifferentiable = false;
+ for (auto modifier : funcDecl->modifiers)
+ {
+ if (auto bwdDiffModifier = as<BackwardDerivativeAttribute>(modifier))
+ bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr);
+ else if (auto fwdDiffModifier = as<ForwardDerivativeAttribute>(modifier))
+ fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr);
+ else if (auto primalSubstModifier = as<PrimalSubstituteAttribute>(modifier))
+ primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr);
+ else if (as<ForwardDifferentiableAttribute>(modifier))
+ isForwardDifferentiable = true;
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ isBackwardDifferentiable = true;
+ }
+ if (primalSubst)
+ {
+ for (auto modifier : primalSubst->modifiers)
+ {
+ if (as<ForwardDifferentiableAttribute>(modifier))
+ isForwardDifferentiable = true;
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ isBackwardDifferentiable = true;
+ }
+ }
+ if (isBackwardDifferentiable)
+ {
+ sb << "\nForward and backward differentiable\n\n";
+ }
+ if (isForwardDifferentiable)
+ {
+ sb << "\nForward differentiable\n\n";
+ }
+ if (fwdDiff && fwdDiff->getName())
+ sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n";
+ if (bwdDiff && bwdDiff->getName())
+ sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n";
+ if (primalSubst && primalSubst->getName())
+ sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n";
+ }
+
auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(
declRef.getLoc(), SourceLocType::Actual);
appendDefinitionLocation(sb, m_workspace, humaneLoc);
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 11729800c..606c77096 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1413,6 +1413,17 @@ Decl* getParentDecl(Decl* decl)
return decl;
}
+Decl* getParentFunc(Decl* decl)
+{
+ while (decl)
+ {
+ if (as<FunctionDeclBase>(decl))
+ return decl;
+ decl = decl->parentDecl;
+ }
+ return nullptr;
+}
+
static const ImageFormatInfo kImageFormatInfos[] =
{
#define SLANG_IMAGE_FORMAT_INFO(TYPE, COUNT, SIZE) SLANG_SCALAR_TYPE_##TYPE, uint8_t(COUNT), uint8_t(SIZE)
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index dd119cf3a..3ac258f08 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -337,6 +337,8 @@ namespace Slang
/// Get the parent decl, skipping any generic decls in between.
Decl* getParentDecl(Decl* decl);
+ Decl* getParentFunc(Decl* decl);
+
} // namespace Slang
#endif
diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang
index ab2899bdc..9c67fbc04 100644
--- a/tests/autodiff/primal-substitute-3.slang
+++ b/tests/autodiff/primal-substitute-3.slang
@@ -18,6 +18,7 @@ struct A : IFoo
}
}
+// A normal function that calls doSomething().
float original<T : IFoo>(T p, float x)
{
p.doSomething();
diff --git a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt
index d4fa0328a..72b83c696 100644
--- a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt
+++ b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt
@@ -2,7 +2,7 @@
range: 7,4 - 7,9
content:
```
-int value
+(global variable) int value
```
#1: Ordindary comment for `value`.
diff --git a/tests/language-server/robustness-1.slang.expected.txt b/tests/language-server/robustness-1.slang.expected.txt
index 4bfd5a31d..28ba6a292 100644
--- a/tests/language-server/robustness-1.slang.expected.txt
+++ b/tests/language-server/robustness-1.slang.expected.txt
@@ -2,7 +2,7 @@
range: 7,11 - 7,14
content:
```
-vector<float,3>[3] arr
+(local variable) vector<float,3>[3] arr
```
{REDACTED}.slang(8)
diff --git a/tests/language-server/robustness-3.slang.expected.txt b/tests/language-server/robustness-3.slang.expected.txt
index 8dcd04228..fe416c3de 100644
--- a/tests/language-server/robustness-3.slang.expected.txt
+++ b/tests/language-server/robustness-3.slang.expected.txt
@@ -2,7 +2,7 @@
range: 6,20 - 6,25
content:
```
-int index
+(parameter) int index
```
diff --git a/tests/language-server/smoke.slang.expected.txt b/tests/language-server/smoke.slang.expected.txt
index 0560c2a2d..0b37ed9db 100644
--- a/tests/language-server/smoke.slang.expected.txt
+++ b/tests/language-server/smoke.slang.expected.txt
@@ -6,7 +6,7 @@ getSum: 2 ,.;:()[]<>{}*&^%!-=+|/?
range: 24,26 - 24,31
content:
```
-Pair.T Pair<Pair.T, Pair.U>.first
+(field) Pair.T Pair<Pair.T, Pair.U>.first
```