diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-26 22:40:06 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-26 22:40:06 -0700 |
| commit | 7be108c379ccc7da3f46b30a2b5917104155d52b (patch) | |
| tree | b38b6d5483ba63b18d38c282c06dd55ff9c188ea | |
| parent | 3acbe8145c60f4d1e7a180b4602a94269a489df5 (diff) | |
Intellisense: show info on decl kind and differentiability. (#2847)
| -rw-r--r-- | source/slang/slang-ast-print.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-language-server.cpp | 116 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute-3.slang | 1 | ||||
| -rw-r--r-- | tests/language-server/ordinary-comment-hover-info.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/language-server/robustness-1.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/language-server/robustness-3.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/language-server/smoke.slang.expected.txt | 2 |
9 files changed, 137 insertions, 5 deletions
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index b38b58385..23430cd13 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -168,7 +168,9 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) // If the parent declaration is a generic, then we need to print out its // signature - if (parentGenericDeclRef) + if (parentGenericDeclRef && + !declRef.as<GenericValueParamDecl>() && + !declRef.as<GenericTypeParamDecl>()) { auto genSubst = as<GenericSubstitution>(declRef.substitutions.substitutions); if (genSubst) diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index a2cafa55a..b826d7cfe 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -216,6 +216,55 @@ static bool isBoolType(Type* t) return basicType->baseType == BaseType::Bool; } +String getDeclKindString(DeclRef<Decl> declRef) +{ + if (declRef.as<ParamDecl>()) + { + return "(parameter) "; + } + else if (declRef.as<GenericTypeParamDecl>()) + { + return "(generic type parameter) "; + } + else if (declRef.as<GenericValueParamDecl>()) + { + return "(generic value parameter) "; + } + else if (declRef.as<AttributeDecl>()) + { + return "(attribute) "; + } + else if (auto varDecl = declRef.as<VarDeclBase>()) + { + auto parent = declRef.getParent(); + if (as<GenericDecl>(parent)) + parent = parent.getParent(); + if (parent.as<InterfaceDecl>()) + { + return "(associated constant) "; + } + else if (parent.as<AggTypeDeclBase>()) + { + return "(field) "; + } + const char* scopeKind = ""; + if (parent.as<NamespaceDeclBase>()) + scopeKind = "global "; + else if (getParentDecl(declRef.getDecl())) + scopeKind = "local "; + StringBuilder sb; + sb << "("; + sb << scopeKind; + if (varDecl.as<LetDecl>()) + sb << "value"; + else + sb << "variable"; + sb << ") "; + return sb.produceString(); + } + return String(); +} + String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) { if (declRef.getDecl()) @@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) astBuilder, ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); + printer.getStringBuilder() << getDeclKindString(declRef); printer.addDeclSignature(declRef); if (auto varDecl = as<VarDeclBase>(declRef.getDecl())) { @@ -496,6 +546,72 @@ SlangResult LanguageServer::hover( _tryGetDocumentation(sb, version, declRef.getDecl()); + if (auto funcDecl = as<FunctionDeclBase>(declRef.getDecl())) + { + DiagnosticSink sink; + SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink); + SemanticsVisitor semanticsVisitor(&semanticContext); + + auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl); + Decl* bwdDiff = nullptr; + Decl* fwdDiff = nullptr; + Decl* primalSubst = nullptr; + auto getDeclFromExpr = [&](Expr* expr) -> Decl* + { + if (auto declRefExpr = as<DeclRefExpr>(expr)) + return declRefExpr->declRef.getDecl(); + return nullptr; + }; + for (auto& assocDecl : assocDecls) + { + if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc) + fwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc) + bwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc) + primalSubst = assocDecl->decl; + } + bool isBackwardDifferentiable = false; + bool isForwardDifferentiable = false; + for (auto modifier : funcDecl->modifiers) + { + if (auto bwdDiffModifier = as<BackwardDerivativeAttribute>(modifier)) + bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr); + else if (auto fwdDiffModifier = as<ForwardDerivativeAttribute>(modifier)) + fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr); + else if (auto primalSubstModifier = as<PrimalSubstituteAttribute>(modifier)) + primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr); + else if (as<ForwardDifferentiableAttribute>(modifier)) + isForwardDifferentiable = true; + else if (as<BackwardDifferentiableAttribute>(modifier)) + isBackwardDifferentiable = true; + } + if (primalSubst) + { + for (auto modifier : primalSubst->modifiers) + { + if (as<ForwardDifferentiableAttribute>(modifier)) + isForwardDifferentiable = true; + else if (as<BackwardDifferentiableAttribute>(modifier)) + isBackwardDifferentiable = true; + } + } + if (isBackwardDifferentiable) + { + sb << "\nForward and backward differentiable\n\n"; + } + if (isForwardDifferentiable) + { + sb << "\nForward differentiable\n\n"; + } + if (fwdDiff && fwdDiff->getName()) + sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n"; + if (bwdDiff && bwdDiff->getName()) + sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n"; + if (primalSubst && primalSubst->getName()) + sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n"; + } + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( declRef.getLoc(), SourceLocType::Actual); appendDefinitionLocation(sb, m_workspace, humaneLoc); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 11729800c..606c77096 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1413,6 +1413,17 @@ Decl* getParentDecl(Decl* decl) return decl; } +Decl* getParentFunc(Decl* decl) +{ + while (decl) + { + if (as<FunctionDeclBase>(decl)) + return decl; + decl = decl->parentDecl; + } + return nullptr; +} + static const ImageFormatInfo kImageFormatInfos[] = { #define SLANG_IMAGE_FORMAT_INFO(TYPE, COUNT, SIZE) SLANG_SCALAR_TYPE_##TYPE, uint8_t(COUNT), uint8_t(SIZE) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index dd119cf3a..3ac258f08 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -337,6 +337,8 @@ namespace Slang /// Get the parent decl, skipping any generic decls in between. Decl* getParentDecl(Decl* decl); + Decl* getParentFunc(Decl* decl); + } // namespace Slang #endif diff --git a/tests/autodiff/primal-substitute-3.slang b/tests/autodiff/primal-substitute-3.slang index ab2899bdc..9c67fbc04 100644 --- a/tests/autodiff/primal-substitute-3.slang +++ b/tests/autodiff/primal-substitute-3.slang @@ -18,6 +18,7 @@ struct A : IFoo } } +// A normal function that calls doSomething(). float original<T : IFoo>(T p, float x) { p.doSomething(); diff --git a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt index d4fa0328a..72b83c696 100644 --- a/tests/language-server/ordinary-comment-hover-info.slang.expected.txt +++ b/tests/language-server/ordinary-comment-hover-info.slang.expected.txt @@ -2,7 +2,7 @@ range: 7,4 - 7,9 content: ``` -int value +(global variable) int value ``` #1: Ordindary comment for `value`. diff --git a/tests/language-server/robustness-1.slang.expected.txt b/tests/language-server/robustness-1.slang.expected.txt index 4bfd5a31d..28ba6a292 100644 --- a/tests/language-server/robustness-1.slang.expected.txt +++ b/tests/language-server/robustness-1.slang.expected.txt @@ -2,7 +2,7 @@ range: 7,11 - 7,14 content: ``` -vector<float,3>[3] arr +(local variable) vector<float,3>[3] arr ``` {REDACTED}.slang(8) diff --git a/tests/language-server/robustness-3.slang.expected.txt b/tests/language-server/robustness-3.slang.expected.txt index 8dcd04228..fe416c3de 100644 --- a/tests/language-server/robustness-3.slang.expected.txt +++ b/tests/language-server/robustness-3.slang.expected.txt @@ -2,7 +2,7 @@ range: 6,20 - 6,25 content: ``` -int index +(parameter) int index ``` diff --git a/tests/language-server/smoke.slang.expected.txt b/tests/language-server/smoke.slang.expected.txt index 0560c2a2d..0b37ed9db 100644 --- a/tests/language-server/smoke.slang.expected.txt +++ b/tests/language-server/smoke.slang.expected.txt @@ -6,7 +6,7 @@ getSum: 2 ,.;:()[]<>{}*&^%!-=+|/? range: 24,26 - 24,31 content: ``` -Pair.T Pair<Pair.T, Pair.U>.first +(field) Pair.T Pair<Pair.T, Pair.U>.first ``` |
