diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-print.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-ast-print.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 46 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 39 | ||||
| -rw-r--r-- | source/slang/slang-language-server.cpp | 59 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 21 |
9 files changed, 154 insertions, 68 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 94dceef54..aa494ec95 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24461,7 +24461,7 @@ extension< CoopMat<T, S, M, N, R> MapElement(functype(uint32_t, uint32_t, T, expand each Ts)->T mapOp); __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) - static CoopMat<T, S, M, N, R> __MapElement< + internal static CoopMat<T, S, M, N, R> __MapElement< TOperator, TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts> >(This tuple, TOperator mapOp, TFunc mapObj); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index d6027b2b2..1100cdee4 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -800,6 +800,7 @@ class PartiallyAppliedGenericExpr : public Expr FIDDLE(...) public: Expr* originalExpr = nullptr; + Expr* baseExpr = nullptr; /// The generic being applied DeclRef<GenericDecl> baseGenericDeclRef; diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index ee747a4c2..1f7478662 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -1313,12 +1313,34 @@ void ASTPrinter::_addDeclPathRec(const DeclRef<Decl>& declRef, Index depth) else if (depth > 0) { // Write out the generic parameters (only if the depth allows it) - addGenericParams(parentGenericDeclRef); + addGenericParams(parentGenericDeclRef, nullptr); } } } -void ASTPrinter::addGenericParams(const DeclRef<GenericDecl>& genericDeclRef) +struct ParamScope +{ + StringBuilder* sb; + List<Range<Index>>* paramRanges; + Index rangeStart; + ParamScope(StringBuilder* inSb, List<Range<Index>>* outParamRanges) + : sb(inSb), paramRanges(outParamRanges) + { + rangeStart = sb->getLength(); + } + ~ParamScope() + { + if (paramRanges) + { + Index rangeEnd = sb->getLength(); + paramRanges->add(makeRange<Index>(rangeStart, rangeEnd)); + } + } +}; + +void ASTPrinter::addGenericParams( + const DeclRef<GenericDecl>& genericDeclRef, + List<Range<Index>>* outParamRanges) { auto& sb = m_builder; @@ -1331,7 +1353,7 @@ void ASTPrinter::addGenericParams(const DeclRef<GenericDecl>& genericDeclRef) if (!first) sb << ", "; first = false; - + ParamScope paramScope(&sb, outParamRanges); { ScopePart scopePart(this, Part::Type::GenericParamType); sb << getText(genericTypeParam.getName()); @@ -1342,7 +1364,7 @@ void ASTPrinter::addGenericParams(const DeclRef<GenericDecl>& genericDeclRef) if (!first) sb << ", "; first = false; - + ParamScope paramScope(&sb, outParamRanges); { ScopePart scopePart(this, Part::Type::GenericParamValueType); addType(getType(m_astBuilder, genericValParam)); @@ -1358,6 +1380,7 @@ void ASTPrinter::addGenericParams(const DeclRef<GenericDecl>& genericDeclRef) if (!first) sb << ", "; first = false; + ParamScope paramScope(&sb, outParamRanges); { ScopePart scopePart(this, Part::Type::GenericParamType); sb << "each "; @@ -1383,8 +1406,6 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* bool first = true; for (auto paramDeclRef : getParameters(m_astBuilder, funcDeclRef)) { - auto rangeStart = sb.getLength(); - ParamDecl* paramDecl = paramDeclRef.getDecl(); auto paramType = getType(m_astBuilder, paramDeclRef); @@ -1393,9 +1414,10 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* if (!first) { sb << ", "; - rangeStart += 2; } + ParamScope paramScope(&sb, outParamRange); + // Type part. { ScopePart scopePart(this, Part::Type::ParamType); @@ -1442,11 +1464,6 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* sb << " = "; addExpr(paramDecl->initExpr); } - - auto rangeEnd = sb.getLength(); - - if (outParamRange) - outParamRange->add(makeRange<Index>(rangeStart, rangeEnd)); first = false; }; if (auto typePack = as<ConcreteTypePack>(paramType)) @@ -1467,11 +1484,7 @@ void ASTPrinter::addDeclParams(const DeclRef<Decl>& declRef, List<Range<Index>>* } else if (auto genericDeclRef = declRef.as<GenericDecl>()) { - addGenericParams(genericDeclRef); - - addDeclParams( - m_astBuilder->getMemberDeclRef(genericDeclRef, genericDeclRef.getDecl()->inner), - outParamRange); + addGenericParams(genericDeclRef, outParamRange); } else { diff --git a/source/slang/slang-ast-print.h b/source/slang/slang-ast-print.h index 6521fdafe..84ad422b6 100644 --- a/source/slang/slang-ast-print.h +++ b/source/slang/slang-ast-print.h @@ -154,7 +154,9 @@ public: void addDeclSignature(const DeclRef<Decl>& declRef); /// Add generic parameters - void addGenericParams(const DeclRef<GenericDecl>& genericDeclRef); + void addGenericParams( + const DeclRef<GenericDecl>& genericDeclRef, + List<Slang::Range<Index>>* outParamRanges = nullptr); /// Get the specified part type. Returns empty slice if not found UnownedStringSlice getPartSlice(Part::Type partType) const; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f63d1aaeb..31b08e925 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1796,7 +1796,7 @@ Expr* SemanticsVisitor::GetBaseExpr(Expr* expr) } else if (auto partiallyApplied = as<PartiallyAppliedGenericExpr>(expr)) { - return GetBaseExpr(partiallyApplied->originalExpr); + return GetBaseExpr(partiallyApplied->baseExpr); } return nullptr; } @@ -4624,17 +4624,6 @@ Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( bool anyDuplicates = false; int zeroIndexOffset = -1; - if (memberRefExpr->name == getSession()->getCompletionRequestTokenName()) - { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; - suggestions.swizzleBaseType = - memberRefExpr->baseExpression ? memberRefExpr->baseExpression->type : nullptr; - suggestions.elementCount[0] = baseElementRowCount; - suggestions.elementCount[1] = baseElementColCount; - } - String swizzleText = getText(memberRefExpr->name); auto cursor = swizzleText.begin(); @@ -4783,18 +4772,6 @@ Expr* SemanticsVisitor::checkTupleSwizzleExpr(MemberExpr* memberExpr, TupleType* if (tupleElementCount == 0) return checkGeneralMemberLookupExpr(memberExpr, baseTupleType); - if (memberExpr->name == getSession()->getCompletionRequestTokenName()) - { - auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; - suggestions.clear(); - suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; - suggestions.swizzleBaseType = - memberExpr->baseExpression ? memberExpr->baseExpression->type : nullptr; - suggestions.elementCount[0] = (Index)tupleElementCount; - suggestions.elementCount[1] = 0; - return memberExpr; - } - String swizzleText = getText(memberExpr->name); auto span = swizzleText.getUnownedSlice(); Index pos = 0; @@ -5374,6 +5351,27 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas suggestions.elementCount[0] = 1; suggestions.swizzleBaseType = scalarType; } + else if (auto matrixType = as<MatrixExpressionType>(expr->baseExpression->type)) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.swizzleBaseType = matrixType; + suggestions.elementCount[0] = 0; + suggestions.elementCount[1] = 0; + if (auto rowCount = as<ConstantIntVal>(matrixType->getRowCount())) + suggestions.elementCount[0] = rowCount->getValue(); + if (auto colCount = as<ConstantIntVal>(matrixType->getColumnCount())) + suggestions.elementCount[1] = colCount->getValue(); + } + else if (auto tupleType = as<TupleType>(expr->baseExpression->type)) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Swizzle; + suggestions.elementCount[0] = tupleType->getMemberCount(); + suggestions.elementCount[1] = 0; + suggestions.swizzleBaseType = tupleType; + } } } return createLookupResultExpr(expr->name, lookupResult, expr->baseExpression, expr->loc, expr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 5a2b0872f..71e8488d8 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1250,7 +1250,8 @@ Expr* SemanticsVisitor::CompleteOverloadCandidate( { auto expr = m_astBuilder->create<PartiallyAppliedGenericExpr>(); expr->loc = context.loc; - expr->originalExpr = baseExpr; + expr->originalExpr = context.originalExpr; + expr->baseExpr = baseExpr; expr->baseGenericDeclRef = as<DeclRefExpr>(baseExpr)->declRef.as<GenericDecl>(); auto args = tryGetGenericArguments(candidate.subst, expr->baseGenericDeclRef.getDecl()); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 1f01baadc..aa3040f08 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -166,6 +166,26 @@ public: return dispatchIfNotNull(expr->right); } + bool visitAppExprCommon(AppExprBase* expr) + { + if (context->findType == ASTLookupType::Invoke && expr->argumentDelimeterLocs.getCount()) + { + String fileName; + Loc start = context->getLoc(expr->argumentDelimeterLocs.getFirst(), &fileName); + Loc end = context->getLoc(expr->argumentDelimeterLocs.getLast(), nullptr); + if (fileName.getUnownedSlice().endsWithCaseInsensitive(context->sourceFileName) && + start < context->cursorLoc && context->cursorLoc <= end) + { + ASTLookupResult result; + result.path = context->nodePath; + result.path.add(expr); + context->results.add(result); + return true; + } + } + return false; + } + bool visitGenericAppExpr(GenericAppExpr* genericAppExpr) { PushNode pushNodeRAII(context, genericAppExpr); @@ -174,6 +194,8 @@ public: for (auto arg : genericAppExpr->arguments) if (dispatchIfNotNull(arg)) return true; + if (visitAppExprCommon(genericAppExpr)) + return true; return false; } @@ -189,21 +211,8 @@ public: for (auto arg : expr->arguments) if (dispatchIfNotNull(arg)) return true; - if (context->findType == ASTLookupType::Invoke && expr->argumentDelimeterLocs.getCount()) - { - String fileName; - Loc start = context->getLoc(expr->argumentDelimeterLocs.getFirst(), &fileName); - Loc end = context->getLoc(expr->argumentDelimeterLocs.getLast(), nullptr); - if (fileName.getUnownedSlice().endsWithCaseInsensitive(context->sourceFileName) && - start < context->cursorLoc && context->cursorLoc <= end) - { - ASTLookupResult result; - result.path = context->nodePath; - result.path.add(expr); - context->results.add(result); - return true; - } - } + if (visitAppExprCommon(expr)) + return true; return false; } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 048fcb44b..dcf7419bd 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -160,6 +160,7 @@ SlangResult LanguageServer::parseNextMessage() caps.semanticTokensProvider.full = true; caps.semanticTokensProvider.range = false; caps.signatureHelpProvider.triggerCharacters.add("("); + caps.signatureHelpProvider.triggerCharacters.add("<"); caps.signatureHelpProvider.triggerCharacters.add(","); caps.signatureHelpProvider.retriggerCharacters.add(","); for (auto tokenType : kSemanticTokenTypes) @@ -1705,14 +1706,33 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: bool useOriginalExpr = true; if (auto originalDeclRefExpr = as<DeclRefExpr>(appExpr->originalFunctionExpr)) { + // If the original expr doesn't map to a valid declref, we will use the checked + // func expr instead. if (!originalDeclRefExpr->declRef) { useOriginalExpr = false; } } + if (as<GenericAppExpr>(appExpr->originalFunctionExpr)) + { + if (as<DeclRefExpr>(funcExpr)) + { + // If the original function is a fully specialized generic app, use the checked func + // expr for signature help. + useOriginalExpr = false; + } + } if (useOriginalExpr) funcExpr = appExpr->originalFunctionExpr; } + if (auto partialGenAppExpr = as<PartiallyAppliedGenericExpr>(funcExpr)) + { + funcExpr = partialGenAppExpr->originalExpr; + } + if (auto genAppExpr = as<GenericAppExpr>(funcExpr)) + { + funcExpr = genAppExpr->functionExpr; + } if (!funcExpr) { return std::nullopt; @@ -1741,6 +1761,22 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: if (!declRef.getDecl()) return; + // If funcExpr is a direct reference to a generic, we should either + // show the generic signature if we are inside `<>`, or show the function + // parameter signature if we are inside `()`. If we are inside `()`, we will + // need to form a decl ref to the inner decl and show its signature. + if (!as<GenericAppExpr>(appExpr)) + { + if (auto genDeclRef = as<GenericDecl>(declRef)) + { + declRef = createDefaultSubstitutionsIfNeeded( + version->linkage->getASTBuilder(), + &semanticsVisitor, + version->linkage->getASTBuilder()->getMemberDeclRef( + declRef, + genDeclRef.getDecl()->inner)); + } + } // If we have a better match than the current best, we will update response.activeSignature // to this signature. if (auto callableDeclRef = declRef.as<CallableDecl>()) @@ -1832,12 +1868,20 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: { if (auto typeType = as<TypeType>(declRefExpr->type.type)) { - // Look for initializers - auto ctors = - semanticsVisitor.lookupConstructorsInType(typeType->getType(), declRefExpr->scope); - for (auto ctor : ctors) + if (as<GenericDeclRefType>(typeType->getType())) { - addDeclRef(ctor.declRef); + addDeclRef(declRefExpr->declRef); + } + else + { + // Look for initializers + auto ctors = semanticsVisitor.lookupConstructorsInType( + typeType->getType(), + declRefExpr->scope); + for (auto ctor : ctors) + { + addDeclRef(ctor.declRef); + } } } else @@ -1847,8 +1891,13 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: } else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr)) { + bool isGenApp = as<GenericAppExpr>(appExpr) != nullptr; for (auto item : overloadedExpr->lookupResult2) { + // Skip non-generic candidates if we are inside a generic app expr (e.g. + // `f<WE_ARE_HERE>`). + if (isGenApp && !as<GenericDecl>(item.declRef)) + continue; addDeclRef(item.declRef); } } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 891db2867..7f7c39929 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2356,14 +2356,27 @@ static Expr* parseGenericApp(Parser* parser, Expr* base) genericApp->loc = base->loc; genericApp->functionExpr = base; - parser->ReadToken(TokenType::OpLess); + auto opLess = parser->ReadToken(TokenType::OpLess); + genericApp->argumentDelimeterLocs.add(opLess.loc); parser->genericDepth++; - // For now assume all generics have at least one argument - genericApp->arguments.add(_parseGenericArg(parser)); - while (AdvanceIf(parser, TokenType::Comma)) + + for (;;) { + if (parser->LookAheadToken(TokenType::OpGreater) || + parser->LookAheadToken(TokenType::OpRsh)) + break; genericApp->arguments.add(_parseGenericArg(parser)); + if (parser->LookAheadToken(TokenType::Comma)) + { + auto commaToken = parser->ReadToken(TokenType::Comma); + genericApp->argumentDelimeterLocs.add(commaToken.loc); + } + else + { + break; + } } + genericApp->argumentDelimeterLocs.add(parser->tokenReader.peekLoc()); parser->genericDepth--; if (parser->tokenReader.peekToken().type == TokenType::OpRsh) |
