diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-26 22:40:06 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-26 22:40:06 -0700 |
| commit | 7be108c379ccc7da3f46b30a2b5917104155d52b (patch) | |
| tree | b38b6d5483ba63b18d38c282c06dd55ff9c188ea /source/slang/slang-language-server.cpp | |
| parent | 3acbe8145c60f4d1e7a180b4602a94269a489df5 (diff) | |
Intellisense: show info on decl kind and differentiability. (#2847)
Diffstat (limited to 'source/slang/slang-language-server.cpp')
| -rw-r--r-- | source/slang/slang-language-server.cpp | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index a2cafa55a..b826d7cfe 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -216,6 +216,55 @@ static bool isBoolType(Type* t) return basicType->baseType == BaseType::Bool; } +String getDeclKindString(DeclRef<Decl> declRef) +{ + if (declRef.as<ParamDecl>()) + { + return "(parameter) "; + } + else if (declRef.as<GenericTypeParamDecl>()) + { + return "(generic type parameter) "; + } + else if (declRef.as<GenericValueParamDecl>()) + { + return "(generic value parameter) "; + } + else if (declRef.as<AttributeDecl>()) + { + return "(attribute) "; + } + else if (auto varDecl = declRef.as<VarDeclBase>()) + { + auto parent = declRef.getParent(); + if (as<GenericDecl>(parent)) + parent = parent.getParent(); + if (parent.as<InterfaceDecl>()) + { + return "(associated constant) "; + } + else if (parent.as<AggTypeDeclBase>()) + { + return "(field) "; + } + const char* scopeKind = ""; + if (parent.as<NamespaceDeclBase>()) + scopeKind = "global "; + else if (getParentDecl(declRef.getDecl())) + scopeKind = "local "; + StringBuilder sb; + sb << "("; + sb << scopeKind; + if (varDecl.as<LetDecl>()) + sb << "value"; + else + sb << "variable"; + sb << ") "; + return sb.produceString(); + } + return String(); +} + String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) { if (declRef.getDecl()) @@ -225,6 +274,7 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) astBuilder, ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); + printer.getStringBuilder() << getDeclKindString(declRef); printer.addDeclSignature(declRef); if (auto varDecl = as<VarDeclBase>(declRef.getDecl())) { @@ -496,6 +546,72 @@ SlangResult LanguageServer::hover( _tryGetDocumentation(sb, version, declRef.getDecl()); + if (auto funcDecl = as<FunctionDeclBase>(declRef.getDecl())) + { + DiagnosticSink sink; + SharedSemanticsContext semanticContext(version->linkage, getModule(funcDecl), &sink); + SemanticsVisitor semanticsVisitor(&semanticContext); + + auto assocDecls = semanticContext.getAssociatedDeclsForDecl(funcDecl); + Decl* bwdDiff = nullptr; + Decl* fwdDiff = nullptr; + Decl* primalSubst = nullptr; + auto getDeclFromExpr = [&](Expr* expr) -> Decl* + { + if (auto declRefExpr = as<DeclRefExpr>(expr)) + return declRefExpr->declRef.getDecl(); + return nullptr; + }; + for (auto& assocDecl : assocDecls) + { + if (assocDecl->kind == DeclAssociationKind::ForwardDerivativeFunc) + fwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::BackwardDerivativeFunc) + bwdDiff = assocDecl->decl; + else if (assocDecl->kind == DeclAssociationKind::PrimalSubstituteFunc) + primalSubst = assocDecl->decl; + } + bool isBackwardDifferentiable = false; + bool isForwardDifferentiable = false; + for (auto modifier : funcDecl->modifiers) + { + if (auto bwdDiffModifier = as<BackwardDerivativeAttribute>(modifier)) + bwdDiff = getDeclFromExpr(bwdDiffModifier->funcExpr); + else if (auto fwdDiffModifier = as<ForwardDerivativeAttribute>(modifier)) + fwdDiff = getDeclFromExpr(fwdDiffModifier->funcExpr); + else if (auto primalSubstModifier = as<PrimalSubstituteAttribute>(modifier)) + primalSubst = getDeclFromExpr(primalSubstModifier->funcExpr); + else if (as<ForwardDifferentiableAttribute>(modifier)) + isForwardDifferentiable = true; + else if (as<BackwardDifferentiableAttribute>(modifier)) + isBackwardDifferentiable = true; + } + if (primalSubst) + { + for (auto modifier : primalSubst->modifiers) + { + if (as<ForwardDifferentiableAttribute>(modifier)) + isForwardDifferentiable = true; + else if (as<BackwardDifferentiableAttribute>(modifier)) + isBackwardDifferentiable = true; + } + } + if (isBackwardDifferentiable) + { + sb << "\nForward and backward differentiable\n\n"; + } + if (isForwardDifferentiable) + { + sb << "\nForward differentiable\n\n"; + } + if (fwdDiff && fwdDiff->getName()) + sb << "Forward derivative: `" << fwdDiff->getName()->text << "`\n\n"; + if (bwdDiff && bwdDiff->getName()) + sb << "Backward derivative: `" << bwdDiff->getName()->text << "`\n\n"; + if (primalSubst && primalSubst->getName()) + sb << "Primal substitute: `" << primalSubst->getName()->text << "`\n\n"; + } + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc( declRef.getLoc(), SourceLocType::Actual); appendDefinitionLocation(sb, m_workspace, humaneLoc); |
