From 7be108c379ccc7da3f46b30a2b5917104155d52b Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 26 Apr 2023 22:40:06 -0700 Subject: Intellisense: show info on decl kind and differentiability. (#2847) --- source/slang/slang-ast-print.cpp | 4 +- source/slang/slang-language-server.cpp | 116 +++++++++++++++++++++ source/slang/slang-syntax.cpp | 11 ++ source/slang/slang-syntax.h | 2 + tests/autodiff/primal-substitute-3.slang | 1 + .../ordinary-comment-hover-info.slang.expected.txt | 2 +- .../robustness-1.slang.expected.txt | 2 +- .../robustness-3.slang.expected.txt | 2 +- tests/language-server/smoke.slang.expected.txt | 2 +- 9 files changed, 137 insertions(+), 5 deletions(-) diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index b38b58385..23430cd13 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -168,7 +168,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef& declRef, Index depth) // If the parent declaration is a generic, then we need to print out its // signature - if (parentGenericDeclRef) + if (parentGenericDeclRef && + !declRef.as() && + !declRef.as()) { auto genSubst = as(declRef.substitutions.substitutions); if (genSubst) diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index a2cafa55a..b826d7cfe 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -216,6 +216,55 @@ static bool isBoolType(Type* t) return basicType->baseType == BaseType::Bool; } +String getDeclKindString(DeclRef declRef) +{ + if (declRef.as()) + { + return "(parameter) "; + } + else if (declRef.as()) + { + return "(generic type parameter) "; + } + else if (declRef.as()) + { + return "(generic value parameter) "; + } + else if (declRef.as()) + { + return "(attribute) "; + } + else if (auto varDecl = declRef.as()) + { + auto parent = declRef.getParent(); + if (as(parent)) + parent = parent.getParent(); + if (parent.as()) + { + return "(associated constant) "; + } + else if (parent.as()) + { + return "(field) "; + } + const char* scopeKind = ""; + if (parent.as()) + scopeKind = "global "; + else if (getParentDecl(declRef.getDecl())) + scopeKind = "local "; + StringBuilder sb; + sb << "("; + sb << scopeKind; + if (varDecl.as()) + sb << "value"; + else + sb << "variable"; + sb << ") "; + return sb.produceString(); + } + return String(); +} + String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) { if (declRef.getDecl()) @@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) astBuilder, ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); + printer.getStringBuilder() << getDeclKindString(declRef); printer.addDeclSignature(declRef); if (auto varDecl = as(declRef.getDecl())) { @@ -496,6 +546,72 @@ SlangResult LanguageServer::hover( _tryGetDocumentation(sb, version, declRef.getDecl()); + if (auto funcDecl = as(declRef.getDecl())) + { + DiagnosticSink sink; + SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink); + SemanticsVisitor semanticsVisitor(&semanticContext); + + auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl); + Decl* bwdDiff = nullptr; + Decl* fwdDiff = nullptr; + Decl* primalSubst = nullptr; + auto getDeclFromExpr = [&](Expr* expr) -> Decl* + { + if (auto declRefExpr = as(expr)) + return declRefExpr->declRef.getDecl(); + return nullptr; + }; + for (auto& assocDecl : assocDecls) + { + if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc) + fwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc) + bwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc) + primalSubst = assocDecl->decl; + } + bool isBackwardDifferentiable = false; + bool isForwardDifferentiable = false; + for (auto modifier : funcDecl->modifiers) + { + if (auto bwdDiffModifier = as(modifier)) + bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr); + else if (auto fwdDiffModifier = as(modifier)) + fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr); + else if (auto primalSubstModifier = as(modifier)) + primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr); + else if (as(modifier)) + isForwardDifferentiable = true; + else if (as(modifier)) + isBackwardDifferentiable = true; + } + if (primalSubst) + { + for (auto modifier : primalSubst->modifiers) + { + if (as(modifier)) + isForwardDifferentiable = true; + else if (as(modifier)) + isBackwardDifferentiable = true; + } + } + if (isBackwardDifferentiable) + { + sb << "\nForward and backward differentiable\n\n"; + } + if (isForwardDifferentiable) + { + sb << "\nForward differentiable\n\n"; + } + if (fwdDiff && fwdDiff->getName()) + sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n"; + if (bwdDiff && bwdDiff->getName()) + sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n"; + if (primalSubst && primalSubst->getName()) + sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n"; + } + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( declRef.getLoc(), SourceLocType::Actual); appendDefinitionLocation(sb, m_workspace, humaneLoc); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 11729800c..606c77096 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1413,6 +1413,17 @@ Decl* getParentDecl(Decl* decl) return decl; } +Decl* getParentFunc(Decl* decl) +{ + while (decl) + { + if (as(decl)) + return decl; + decl = decl->parentDecl; + } + return nullptr; +} + static const ImageFormatInfo kImageFormatInfos[] = { #define SLANG_IMAGE_FORMAT_INFO(TYPE, COUNT, SIZE) SLANG_SCALAR_TYPE_##TYPE, uint8_t(COUNT), uint8_t(SIZE) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index dd119cf3a..3ac258f08 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -337,6 +337,8 @@ namespace Slang /// Get the parent decl, skipping any generic decls in between. Decl* getParentDecl(Decl* decl); + Decl* getParentFunc(Decl* decl); + } // namespace Slang #endif diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang index ab2899bdc..9c67fbc04 100644 --- a/tests/autodiff/primal-substitute-3.slang +++ b/tests/autodiff/primal-substitute-3.slang @@ -18,6 +18,7 @@ struct A : IFoo } } +// A normal function that calls doSomething(). float original(T p, float x) { p.doSomething(); diff --git a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt index d4fa0328a..72b83c696 100644 --- a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt +++ b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt @@ -2,7 +2,7 @@ range: 7,4 - 7,9 content: ``` -int value +(global variable) int value ``` #1: Ordindary comment for `value`. diff --git a/tests/language-server/robustness-1.slang.expected.txt b/tests/language-server/robustness-1.slang.expected.txt index 4bfd5a31d..28ba6a292 100644 --- a/tests/language-server/robustness-1.slang.expected.txt +++ b/tests/language-server/robustness-1.slang.expected.txt @@ -2,7 +2,7 @@ range: 7,11 - 7,14 content: ``` -vector[3] arr +(local variable) vector[3] arr ``` {REDACTED}.slang(8) diff --git a/tests/language-server/robustness-3.slang.expected.txt b/tests/language-server/robustness-3.slang.expected.txt index 8dcd04228..fe416c3de 100644 --- a/tests/language-server/robustness-3.slang.expected.txt +++ b/tests/language-server/robustness-3.slang.expected.txt @@ -2,7 +2,7 @@ range: 6,20 - 6,25 content: ``` -int index +(parameter) int index ``` diff --git a/tests/language-server/smoke.slang.expected.txt b/tests/language-server/smoke.slang.expected.txt index 0560c2a2d..0b37ed9db 100644 --- a/tests/language-server/smoke.slang.expected.txt +++ b/tests/language-server/smoke.slang.expected.txt @@ -6,7 +6,7 @@ getSum: 2 ,.;:()[]<>{}*&^%!-=+|/? range: 24,26 - 24,31 content: ``` -Pair.T Pair.first +(field) Pair.T Pair.first ``` -- cgit v1.2.3