diff options
| author | Yong He <yonghe@outlook.com> | 2025-08-07 08:10:02 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-07 15:10:02 +0000 |
| commit | 7cd8130e1a3dbcca8746e0577fb8df3bf2975bf8 (patch) | |
| tree | 6da2b411da34039c3d0ec0e06fadd0e13f8f4842 | |
| parent | 67a96920674d628f615532a302504544a45e8187 (diff) | |
Support `expand` on concrete tuple values. (#8106)
Closes #8061.
Along with the fix, also enhanced coercion/overload resolution to filter
candidates based on the target type, allowing
`tests\language-feature\higher-order-functions\overloaded.slang` to
pass.
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 28 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 116 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-parser.h | 1 | ||||
| -rw-r--r-- | tests/language-feature/higher-order-functions/overloaded.slang | 13 | ||||
| -rw-r--r-- | tests/language-feature/higher-order-functions/overloaded.slang.expected | 8 | ||||
| -rw-r--r-- | tests/language-feature/tuple/tuple-expand-call.slang | 26 |
12 files changed, 256 insertions, 79 deletions
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 8a1b58aff..f5436a70e 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -56,7 +56,6 @@ class UnparsedStmt : public Stmt Scope* currentScope = nullptr; Scope* outerScope = nullptr; SourceLanguage sourceLanguage; - bool isInVariadicGenerics = false; }; FIDDLE() diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index a605928e3..758c23a5f 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1191,10 +1191,10 @@ bool SemanticsVisitor::_coerce( // then we should start by trying to resolve the ambiguous reference // based on prioritization of the different candidates. // - // TODO: A more powerful model would be to try to coerce each + // If `fromExpr` is overloaded, we will try to coerce each // of the constituent overload candidates, filtering down to // those that are coercible, and then disambiguating the result. - // Such an approach would let us disambiguate between overloaded + // Such an approach lets us disambiguate between overloaded // symbols based on their type (e.g., by casting the name of // an overloaded function to the type of the overload we mean // to reference). @@ -1202,11 +1202,49 @@ bool SemanticsVisitor::_coerce( if (auto fromOverloadedExpr = as<OverloadedExpr>(fromExpr)) { auto resolvedExpr = - maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, nullptr); + maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, toType, nullptr); fromExpr = resolvedExpr; fromType = resolvedExpr->type; } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(fromExpr)) + { + ShortList<Expr*> coercibleCandidates; + for (auto candidate : overloadedExpr2->candidateExprs) + { + if (canCoerce(toType, candidate->type, candidate)) + coercibleCandidates.add(candidate); + } + if (coercibleCandidates.getCount() == 1) + { + return _coerce( + site, + toType, + outToExpr, + coercibleCandidates[0]->type, + coercibleCandidates[0], + sink, + outCost); + } + if (sink) + { + auto firstCandidate = overloadedExpr2->candidateExprs.getCount() > 0 + ? overloadedExpr2->candidateExprs[0] + : nullptr; + if (auto declCandidate = as<DeclRefExpr>(firstCandidate)) + { + sink->diagnose( + fromExpr->loc, + Diagnostics::ambiguousReference, + declCandidate->declRef); + } + else + { + sink->diagnose(fromExpr->loc, Diagnostics::ambiguousExpression); + } + } + return false; + } // An important and easy case is when the "to" and "from" types are equal. // diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index dd785f0a8..d6bc50b82 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10424,7 +10424,6 @@ Stmt* SemanticsVisitor::maybeParseStmt(Stmt* stmt, const SemanticsContext& conte &subVisitor, getShared()->getTranslationUnitRequest(), unparsedStmt->sourceLanguage, - unparsedStmt->isInVariadicGenerics, tokenList, getShared()->getSink(), unparsedStmt->currentScope, diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 425548518..a874eaf43 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1128,7 +1128,9 @@ LookupResult SemanticsVisitor::filterLookupResultByCheckedOptionalAndDiagnose( return result; } -LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult) +LookupResult SemanticsVisitor::resolveOverloadedLookup( + LookupResult const& inResult, + Type* targetType) { // If the result isn't actually overloaded, it is fine as-is if (!inResult.isValid()) @@ -1140,6 +1142,15 @@ LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inRes List<LookupResultItem> items; for (auto item : inResult.items) { + // First we check if the item is coercible to targetType. + // And skip if it doesn't. + if (targetType) + { + auto declType = GetTypeForDeclRef(item.declRef, SourceLoc()); + if (!canCoerce(targetType, declType, nullptr, nullptr)) + continue; + } + // For each item we consider adding, we will compare it // to those items we've already added. // @@ -1232,6 +1243,7 @@ void SemanticsVisitor::diagnoseAmbiguousReference(Expr* expr) Expr* SemanticsVisitor::_resolveOverloadedExprImpl( OverloadedExpr* overloadedExpr, LookupMask mask, + Type* targetType, DiagnosticSink* diagSink) { auto lookupResult = overloadedExpr->lookupResult2; @@ -1246,7 +1258,7 @@ Expr* SemanticsVisitor::_resolveOverloadedExprImpl( lookupResult = refineLookup(lookupResult, mask); // Try to filter out overload candidates based on which ones are "better" than one another. - lookupResult = resolveOverloadedLookup(lookupResult); + lookupResult = resolveOverloadedLookup(lookupResult, targetType); if (!lookupResult.isValid()) { @@ -1296,6 +1308,7 @@ Expr* SemanticsVisitor::_resolveOverloadedExprImpl( Expr* SemanticsVisitor::maybeResolveOverloadedExpr( Expr* expr, LookupMask mask, + Type* targetType, DiagnosticSink* diagSink) { if (IsErrorExpr(expr)) @@ -1303,7 +1316,7 @@ Expr* SemanticsVisitor::maybeResolveOverloadedExpr( if (auto overloadedExpr = as<OverloadedExpr>(expr)) { - return _resolveOverloadedExprImpl(overloadedExpr, mask, diagSink); + return _resolveOverloadedExprImpl(overloadedExpr, mask, targetType, diagSink); } else { @@ -1311,9 +1324,12 @@ Expr* SemanticsVisitor::maybeResolveOverloadedExpr( } } -Expr* SemanticsVisitor::resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask) +Expr* SemanticsVisitor::resolveOverloadedExpr( + OverloadedExpr* overloadedExpr, + Type* targetType, + LookupMask mask) { - return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); + return _resolveOverloadedExprImpl(overloadedExpr, mask, targetType, getSink()); } Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) @@ -1364,7 +1380,7 @@ Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type) Slang::LookupMask::type, Slang::LookupOptions::None); - diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult, nullptr); if (!diffTypeLookupResult.isValid()) { @@ -4450,7 +4466,7 @@ Expr* SemanticsExprVisitor::visitEachExpr(EachExpr* expr) { goto error; } - if (!declRefType->getDeclRef().as<GenericTypePackParamDecl>()) + if (!declRefType->getDeclRef().as<GenericTypePackParamDecl>() && !as<TupleType>(baseType)) { goto error; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 48b06eb9f..94c595321 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1422,7 +1422,11 @@ public: DeclRef<Decl> resolveDeclRef(DeclRef<Decl> declRef); /// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results - LookupResult resolveOverloadedLookup(LookupResult const& lookupResult); + LookupResult resolveOverloadedLookup(LookupResult const& lookupResult, Type* targetType); + inline LookupResult resolveOverloadedLookup(LookupResult const& lookupResult) + { + return resolveOverloadedLookup(lookupResult, nullptr); + } /// Attempt to resolve `expr` into an expression that refers to a single declaration/value. /// If `expr` isn't overloaded, then it will be returned as-is. @@ -1434,19 +1438,33 @@ public: /// appropriate "ambiguous reference" error will be reported, and an error expression will be /// returned. Otherwise, the original expression is returned if resolution fails. /// - Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink); + Expr* maybeResolveOverloadedExpr( + Expr* expr, + LookupMask mask, + Type* targetType, + DiagnosticSink* diagSink); + + inline Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink) + { + return maybeResolveOverloadedExpr(expr, mask, nullptr, diagSink); + } /// Attempt to resolve `overloadedExpr` into an expression that refers to a single /// declaration/value. /// /// Equivalent to `maybeResolveOverloadedExpr` with `diagSink` bound to the sink for the /// `SemanticsVisitor`. - Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask); + Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, Type* targetType, LookupMask mask); + inline Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask) + { + return resolveOverloadedExpr(overloadedExpr, nullptr, mask); + } /// Worker reoutine for `maybeResolveOverloadedExpr` and `resolveOverloadedExpr`. Expr* _resolveOverloadedExprImpl( OverloadedExpr* overloadedExpr, LookupMask mask, + Type* targetType, DiagnosticSink* diagSink); void diagnoseAmbiguousReference( @@ -2857,6 +2875,10 @@ public: void AddGenericOverloadCandidates(Expr* baseExpr, OverloadResolveContext& context); + // Given an argument list, expand all `expand` expressions, if the type/value pack being + // expanded is already specialized. + void maybeExpandArgList(List<Expr*>& args); + template<class T> void trySetGenericToRayTracingWithParamAttribute( LookupResultItem genericItem, diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2c17a6380..f949e2632 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2065,6 +2065,94 @@ void SemanticsVisitor::AddCtorOverloadCandidate( AddOverloadCandidate(context, candidate, baseCost); } +void SemanticsVisitor::maybeExpandArgList(List<Expr*>& args) +{ + bool needExpansion = false; + for (auto expr : args) + { + while (auto paren = as<ParenExpr>(expr)) + expr = paren->base; + + if (auto expand = as<ExpandExpr>(expr)) + { + auto exprType = expand->type.type; + if (auto typeType = as<TypeType>(exprType)) + exprType = typeType->getType(); + if (as<ConcreteTypePack>(exprType)) + { + needExpansion = true; + } + } + } + // Fast path without creating list copies. + if (!needExpansion) + return; + List<Expr*> result; + for (auto expr : args) + { + while (auto paren = as<ParenExpr>(expr)) + expr = paren->base; + auto processExpr = [&]() + { + auto expand = as<ExpandExpr>(expr); + if (!expand) + return false; + auto type = expand->type.type; + if (auto typeType = as<TypeType>(type)) + { + auto typePack = as<ConcreteTypePack>(typeType->getType()); + if (!typePack) + return false; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto expandArg = m_astBuilder->create<SharedTypeExpr>(); + expandArg->loc = expr->loc; + expandArg->type = m_astBuilder->getTypeType(typePack->getElementType(i)); + result.add(expandArg); + } + return true; + } + else if (auto typePack = as<ConcreteTypePack>(type)) + { + auto localScope = getExprLocalScope(); + SLANG_ASSERT(localScope); + + VarDecl* varDecl = m_astBuilder->create<VarDecl>(); + varDecl->parentDecl = nullptr; + if (m_outerScope && m_outerScope->containerDecl) + m_outerScope->containerDecl->addMember(varDecl); + addModifier(varDecl, m_astBuilder->create<LocalTempVarModifier>()); + varDecl->checkState = DeclCheckState::DefinitionChecked; + varDecl->nameAndLoc.loc = expr->loc; + varDecl->initExpr = expr; + varDecl->type.type = expr->type.type; + LetExpr* letExpr = m_astBuilder->create<LetExpr>(); + letExpr->decl = varDecl; + localScope->addBinding(letExpr); + auto varExpr = m_astBuilder->create<VarExpr>(); + varExpr->declRef = varDecl; + varExpr->type = expr->type.type; + varExpr->type.isLeftValue = false; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto expandedArg = m_astBuilder->create<SwizzleExpr>(); + expandedArg->base = varExpr; + expandedArg->type = typePack->getElementType(i); + expandedArg->type.isLeftValue = false; + expandedArg->elementIndices.add((uint32_t)i); + result.add(expandedArg); + } + return true; + } + return false; + }; + + if (!processExpr()) + result.add(expr); + } + args.swapWith(result); +} + bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams( SemanticsVisitor* semantics, const List<QualType>& params, @@ -2102,7 +2190,8 @@ bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams( } // Try to match the variadic part. - // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param. + // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack + // param. auto astBuilder = semantics->getASTBuilder(); if (remainingArgCount <= 0) @@ -2655,22 +2744,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) { return CreateErrorExpr(expr); } - // If any of the arguments is an error, then we should bail out, to avoid - // cascading errors where we successfully pick an overload, but not the one - // the user meant. - for (auto arg : expr->arguments) - { - if (IsErrorExpr(arg)) - return CreateErrorExpr(expr); - // If this argument is itself an overloaded value without a type - // then we can't sensibly continue - if (!arg->type && (as<OverloadedExpr>(arg) || as<OverloadedExpr2>(arg))) - { - getSink()->diagnose(expr->loc, Diagnostics::overloadedParameterToHigherOrderFunction); - return CreateErrorExpr(expr); - } - } + maybeExpandArgList(expr->arguments); for (auto& arg : expr->arguments) { @@ -2700,7 +2775,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate)) { // We should only use the cached candidate if it is persistent direct declref - // created from GlobalSession's ASTBuilder, or it is created in the current Linkage. + // created from GlobalSession's ASTBuilder, or it is created in the current + // Linkage. if (candidate.cacheVersion == typeCheckingCache->version || findNextOuterGeneric(candidate.decl) == nullptr) { @@ -2910,8 +2986,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr) } } - // Now that we have resolved the overload candidate, we need to undo an `openExistential` - // operation that was applied to `out` arguments. + // Now that we have resolved the overload candidate, we need to undo an + // `openExistential` operation that was applied to `out` arguments. // auto funcType = context.bestCandidate->funcType; ShortList<ParameterDirection> paramDirections; @@ -3087,6 +3163,7 @@ Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAp auto& baseExpr = genericAppExpr->functionExpr; auto& args = genericAppExpr->arguments; + maybeExpandArgList(args); // If there was an error in the base expression, or in any of // the arguments, then just bail. @@ -3138,6 +3215,7 @@ Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAp // to complete all of them and create an overloaded expression as a result. auto overloadedExpr = m_astBuilder->create<OverloadedExpr2>(); + overloadedExpr->type = m_astBuilder->getOverloadedType(); overloadedExpr->base = context.baseExpr; for (auto candidate : context.bestCandidates) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 284cb634c..6f0a6274e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1989,12 +1989,6 @@ DIAGNOSTIC( DIAGNOSTIC( 39999, Error, - overloadedParameterToHigherOrderFunction, - "passing overloaded functions to higher order functions is not supported") - -DIAGNOSTIC( - 39999, - Error, matrixColumnOrRowCountIsOne, "matrices with 1 column or row are not supported by the current code generation target") diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 7f7c39929..7e9740b53 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -118,10 +118,6 @@ public: bool hasSeenCompletionToken = false; - // Track whether or not we are inside a generics that has variadic parameters. - // If so we will enable the new `expand` and `each` keyword. - bool isInVariadicGenerics = false; - TokenReader tokenReader; DiagnosticSink* sink; SourceLoc lastErrorLoc; @@ -1619,9 +1615,6 @@ static void ParseGenericDeclImpl(Parser* parser, GenericDecl* decl, const TFunc& { parser->ReadToken(TokenType::OpLess); parser->genericDepth++; - bool oldIsInVariadicGenerics = parser->isInVariadicGenerics; - SLANG_DEFER(parser->isInVariadicGenerics = oldIsInVariadicGenerics); - for (;;) { const TokenType tokenType = parser->tokenReader.peekTokenType(); @@ -1635,11 +1628,6 @@ static void ParseGenericDeclImpl(Parser* parser, GenericDecl* decl, const TFunc& auto genericParam = ParseGenericParamDecl(parser, decl); AddMember(decl, genericParam); - if (as<GenericTypePackParamDecl>(genericParam)) - { - parser->isInVariadicGenerics = true; - } - // Make sure we make forward progress. if (parser->tokenReader.getCursor() == currentCursor) advanceToken(parser); @@ -1869,7 +1857,6 @@ static Stmt* parseOptBody(Parser* parser) unparsedStmt->currentScope = parser->currentScope; unparsedStmt->outerScope = parser->outerScope; unparsedStmt->sourceLanguage = parser->getSourceLanguage(); - unparsedStmt->isInVariadicGenerics = parser->isInVariadicGenerics; parser->FillPosition(unparsedStmt); List<Token>& tokens = unparsedStmt->tokens; int braceDepth = 0; @@ -8616,6 +8603,32 @@ static Expr* parseEachExpr(Parser* parser, SourceLoc loc) return eachExpr; } +// Check if a specific contextual keyword is available and not shadowed by user-defined decls. +static bool isKeywordAvailable(Parser* parser, const char* keyword) +{ + if (!parser->semanticsVisitor || !parser->currentLookupScope) + return true; + if (lookUp( + parser->astBuilder, + parser->semanticsVisitor, + parser->semanticsVisitor->getName(keyword), + parser->currentLookupScope) + .isValid()) + return false; + return true; +} + +// Advance the token reader if the next token is a keyword and the keyword is not shadowed. +static bool advanceIfAvailableKeyword(Parser* parser, const char* keyword) +{ + if (parser->LookAheadToken(keyword) && isKeywordAvailable(parser, keyword)) + { + parser->ReadToken(); + return true; + } + return false; +} + static Expr* parsePrefixExpr(Parser* parser) { auto tokenType = peekTokenType(parser); @@ -8650,18 +8663,13 @@ static Expr* parsePrefixExpr(Parser* parser) { return parseSPIRVAsmExpr(parser, tokenLoc); } - else if (parser->isInVariadicGenerics) + else if (advanceIfAvailableKeyword(parser, "expand")) { - // If we are inside a variadic generic, we also need to recognize - // the new `expand` and `each` keyword for dealing with variadic packs. - if (AdvanceIf(parser, "expand")) - { - return parseExpandExpr(parser, tokenLoc); - } - else if (AdvanceIf(parser, "each")) - { - return parseEachExpr(parser, tokenLoc); - } + return parseExpandExpr(parser, tokenLoc); + } + else if (advanceIfAvailableKeyword(parser, "each")) + { + return parseEachExpr(parser, tokenLoc); } return parsePostfixExpr(parser); } @@ -8795,7 +8803,6 @@ Stmt* parseUnparsedStmt( SemanticsVisitor* semanticsVisitor, TranslationUnitRequest* translationUnit, SourceLanguage sourceLanguage, - bool isInVariadicGenerics, TokenSpan const& tokens, DiagnosticSink* sink, Scope* currentScope, @@ -8819,7 +8826,6 @@ Stmt* parseUnparsedStmt( parser.semanticsVisitor = semanticsVisitor; parser.currentScope = parser.currentLookupScope = currentScope; parser.currentModule = semanticsVisitor->getShared()->getModule()->getModuleDecl(); - parser.isInVariadicGenerics = isInVariadicGenerics; return parser.parseBlockStatement(); } @@ -9388,7 +9394,8 @@ static NodeBase* parseMagicTypeModifier(Parser* parser, void* /*userData*/) { modifier->magicNodeType = syntaxClass; } - // TODO: print diagnostic if the magic type name doesn't correspond to an actual ASTNodeType. + // TODO: print diagnostic if the magic type name doesn't correspond to an actual + // ASTNodeType. parser->ReadToken(TokenType::RParent); return modifier; diff --git a/source/slang/slang-parser.h b/source/slang/slang-parser.h index c4e68a7fa..5ca1b9f69 100644 --- a/source/slang/slang-parser.h +++ b/source/slang/slang-parser.h @@ -32,7 +32,6 @@ Stmt* parseUnparsedStmt( SemanticsVisitor* semantics, TranslationUnitRequest* translationUnit, SourceLanguage sourceLanguage, - bool isInVariadicGenerics, TokenSpan const& tokens, DiagnosticSink* sink, Scope* currentScope, diff --git a/tests/language-feature/higher-order-functions/overloaded.slang b/tests/language-feature/higher-order-functions/overloaded.slang index ebbc6ed19..d3d495bc1 100644 --- a/tests/language-feature/higher-order-functions/overloaded.slang +++ b/tests/language-feature/higher-order-functions/overloaded.slang @@ -1,11 +1,11 @@ -//TEST:SIMPLE: +//TEST:INTERPRET(filecheck=CHECK): func foo(f : functype (float) -> int) -> int{ return f(0); } int bit<T>(T) { - return 1; + return 10; } int bit<T, let N : int>(vector<T, N>) { @@ -13,6 +13,13 @@ int bit<T, let N : int>(vector<T, N>) { } int zit() { - // In an ideal world in this case we could infer that we want bit<T> + // even though foo is overloaded, we should still be able to infer that we want bit<T> + // based on the parameter (expected) type. return foo(bit<float>); } + +void main() +{ + // CHECK: 10 + printf("%d\n", zit()); +}
\ No newline at end of file diff --git a/tests/language-feature/higher-order-functions/overloaded.slang.expected b/tests/language-feature/higher-order-functions/overloaded.slang.expected deleted file mode 100644 index 3d02bf06f..000000000 --- a/tests/language-feature/higher-order-functions/overloaded.slang.expected +++ /dev/null @@ -1,8 +0,0 @@ -result code = -1 -standard error = { -tests/language-feature/higher-order-functions/overloaded.slang(17): error 39999: passing overloaded functions to higher order functions is not supported - return foo(bit<float>); - ^ -} -standard output = { -} diff --git a/tests/language-feature/tuple/tuple-expand-call.slang b/tests/language-feature/tuple/tuple-expand-call.slang new file mode 100644 index 000000000..f2293f03f --- /dev/null +++ b/tests/language-feature/tuple/tuple-expand-call.slang @@ -0,0 +1,26 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0], stride=4) + +#lang 2026 + +RWStructuredBuffer<int> outputBuffer; + +int f(int x, float y) { return x + int(y); } + +int g<each T>(expand each Tuple<expand each T> t) +{ + return countof(t); +} + +[numthreads(1,1,1)] +void computeMain() +{ + Tuple<expand each Tuple<int, float>> x = (2, 3.0f); + + // CHECK: 2 + outputBuffer[0] = g<int, float>(expand each x); + + // CHECK: 5 + outputBuffer[1] = f(expand each x); +} |
