From 7be108c379ccc7da3f46b30a2b5917104155d52b Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 26 Apr 2023 22:40:06 -0700 Subject: Intellisense: show info on decl kind and differentiability. (#2847) --- source/slang/slang-language-server.cpp | 116 +++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) (limited to 'source/slang/slang-language-server.cpp') diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index a2cafa55a..b826d7cfe 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -216,6 +216,55 @@ static bool isBoolType(Type* t) return basicType->baseType == BaseType::Bool; } +String getDeclKindString(DeclRef declRef) +{ + if (declRef.as()) + { + return "(parameter) "; + } + else if (declRef.as()) + { + return "(generic type parameter) "; + } + else if (declRef.as()) + { + return "(generic value parameter) "; + } + else if (declRef.as()) + { + return "(attribute) "; + } + else if (auto varDecl = declRef.as()) + { + auto parent = declRef.getParent(); + if (as(parent)) + parent = parent.getParent(); + if (parent.as()) + { + return "(associated constant) "; + } + else if (parent.as()) + { + return "(field) "; + } + const char* scopeKind = ""; + if (parent.as()) + scopeKind = "global "; + else if (getParentDecl(declRef.getDecl())) + scopeKind = "local "; + StringBuilder sb; + sb << "("; + sb << scopeKind; + if (varDecl.as()) + sb << "value"; + else + sb << "variable"; + sb << ") "; + return sb.produceString(); + } + return String(); +} + String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) { if (declRef.getDecl()) @@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) astBuilder, ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); + printer.getStringBuilder() << getDeclKindString(declRef); printer.addDeclSignature(declRef); if (auto varDecl = as(declRef.getDecl())) { @@ -496,6 +546,72 @@ SlangResult LanguageServer::hover( _tryGetDocumentation(sb, version, declRef.getDecl()); + if (auto funcDecl = as(declRef.getDecl())) + { + DiagnosticSink sink; + SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink); + SemanticsVisitor semanticsVisitor(&semanticContext); + + auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl); + Decl* bwdDiff = nullptr; + Decl* fwdDiff = nullptr; + Decl* primalSubst = nullptr; + auto getDeclFromExpr = [&](Expr* expr) -> Decl* + { + if (auto declRefExpr = as(expr)) + return declRefExpr->declRef.getDecl(); + return nullptr; + }; + for (auto& assocDecl : assocDecls) + { + if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc) + fwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc) + bwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc) + primalSubst = assocDecl->decl; + } + bool isBackwardDifferentiable = false; + bool isForwardDifferentiable = false; + for (auto modifier : funcDecl->modifiers) + { + if (auto bwdDiffModifier = as(modifier)) + bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr); + else if (auto fwdDiffModifier = as(modifier)) + fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr); + else if (auto primalSubstModifier = as(modifier)) + primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr); + else if (as(modifier)) + isForwardDifferentiable = true; + else if (as(modifier)) + isBackwardDifferentiable = true; + } + if (primalSubst) + { + for (auto modifier : primalSubst->modifiers) + { + if (as(modifier)) + isForwardDifferentiable = true; + else if (as(modifier)) + isBackwardDifferentiable = true; + } + } + if (isBackwardDifferentiable) + { + sb << "\nForward and backward differentiable\n\n"; + } + if (isForwardDifferentiable) + { + sb << "\nForward differentiable\n\n"; + } + if (fwdDiff && fwdDiff->getName()) + sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n"; + if (bwdDiff && bwdDiff->getName()) + sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n"; + if (primalSubst && primalSubst->getName()) + sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n"; + } + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( declRef.getLoc(), SourceLocType::Actual); appendDefinitionLocation(sb, m_workspace, humaneLoc); -- cgit v1.2.3