From 071f1b6062b459928ebfd6f2f60a8d6ad021112b Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 14 Aug 2024 18:41:48 -0700 Subject: Variadic Generics Part 1: parsing and type checking. (#4833) --- source/slang/slang-check-overload.cpp | 401 +++++++++++++++++++++++++++++----- 1 file changed, 346 insertions(+), 55 deletions(-) (limited to 'source/slang/slang-check-overload.cpp') diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index d37c6e469..16f6ed7da 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -15,22 +15,37 @@ namespace Slang ParamCounts counts = { 0, 0 }; for (auto param : params) { - counts.allowed++; - - // No initializer means no default value - // - // TODO(tfoley): The logic here is currently broken in two ways: - // - // 1. We are assuming that once one parameter has a default, then all do. - // This can/should be validated earlier, so that we can assume it here. - // - // 2. We are not handling the possibility of multiple declarations for - // a single function, where we'd need to merge default parameters across - // all the declarations. - if (!param.getDecl()->initExpr) + Index allowedArgCountToAdd = 1; + auto paramType = getParamType(m_astBuilder, param); + if (isTypePack(paramType)) { + if (auto typePack = as(paramType)) + { + counts.required += typePack->getTypeCount(); + allowedArgCountToAdd = typePack->getTypeCount(); + } + else + { + counts.allowed = -1; + } + } + else if (!param.getDecl()->initExpr) + { + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. counts.required++; } + + if (counts.allowed >= 0) + counts.allowed += allowedArgCountToAdd; } return counts; } @@ -42,7 +57,8 @@ namespace Slang { if (auto typeParam = as(m)) { - counts.allowed++; + if (counts.allowed >= 0) + counts.allowed++; if (!typeParam->initType.Ptr()) { counts.required++; @@ -50,12 +66,17 @@ namespace Slang } else if (auto valParam = as(m)) { - counts.allowed++; + if (counts.allowed >= 0) + counts.allowed++; if (!valParam->initExpr) { counts.required++; } } + else if (as(m)) + { + counts.allowed = -1; + } } return counts; } @@ -130,7 +151,7 @@ namespace Slang break; } - if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) + if (argCount >= paramCounts.required && (paramCounts.allowed == -1 || argCount <= paramCounts.allowed)) return true; // Emit an error message if we are checking this call for real @@ -270,13 +291,35 @@ namespace Slang getSink()->diagnose(context.loc, Diagnostics::cannotSpecializeGeneric, candidate.item.declRef); } }; + List paramTypes; + for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) + { + if (auto typeParamRef = memberRef.as()) + { + paramTypes.add(DeclRefType::create(m_astBuilder, typeParamRef)); + } + else if (auto valParamRef = memberRef.as()) + { + paramTypes.add(getType(m_astBuilder, valParamRef)); + } + else if (auto typePackParam = memberRef.as()) + { + paramTypes.add(DeclRefType::create(m_astBuilder, typePackParam)); + } + } + ShortList matchedArgs; + if (!context.matchArgumentsToParams(this, paramTypes, false, matchedArgs)) + { + maybeReportGeneralError(); + return false; + } Index aa = 0; for (auto memberRef : getMembers(m_astBuilder, genericDeclRef)) { if (auto typeParamRef = memberRef.as()) { - if (aa >= context.argCount) + if (aa >= matchedArgs.getCount()) { if (allowPartialGenericApp) { @@ -319,15 +362,15 @@ namespace Slang // the checking "for real" in which case any errors // we run into need to be reported. // - auto arg = context.getArg(aa++); + auto arg = matchedArgs[aa++]; if (context.mode == OverloadResolveContext::Mode::JustTrying) { - typeArg = tryCoerceToProperType(TypeExp(arg)); + typeArg = tryCoerceToProperType(TypeExp(arg.argExpr)); } else { - arg = ExpectATypeRepr(arg); - typeArg = CoerceToProperType(TypeExp(arg)); + arg.argExpr = ExpectATypeRepr(arg.argExpr); + typeArg = CoerceToProperType(TypeExp(arg.argExpr)); } // If we failed to get a valid type (either because @@ -345,7 +388,7 @@ namespace Slang } else if (auto valParamRef = memberRef.as()) { - if (aa >= context.argCount) + if (aa >= matchedArgs.getCount()) { if (allowPartialGenericApp) { @@ -384,7 +427,7 @@ namespace Slang // to the type of the parameter (and fail if the // coercion is not possible) // - arg = context.getArg(aa++); + arg = matchedArgs[aa++].argExpr; if (context.mode == OverloadResolveContext::Mode::JustTrying) { ConversionCost cost = kConversionCost_None; @@ -418,6 +461,78 @@ namespace Slang } checkedArgs.add(val); } + else if (auto typePackParam = memberRef.as()) + { + Val* val = nullptr; + if (aa >= matchedArgs.getCount()) + { + // If we run out of matched args, we will just create an empty pack. + val = m_astBuilder->getTypePack(ArrayView()); + } + else + { + auto matchedArg = matchedArgs[aa++]; + if (auto packExpr = as(matchedArg.argExpr)) + { + // We are providing a concrete pack of types as arguments to a type pack parameter. + // We need to create a `TypePack` type to serve as the argument. + ShortList coercedProperTypes; + + // Coerce all types in the pack to proper types. + for (Index i = 0; i < packExpr->args.getCount(); i++) + { + TypeExp typeArg; + auto elementTypeExpr = packExpr->args[i]; + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + typeArg = tryCoerceToProperType(TypeExp(elementTypeExpr)); + if (!typeArg.type) + { + typeArg.type = m_astBuilder->getErrorType(); + success = false; + } + } + else + { + elementTypeExpr = ExpectATypeRepr(elementTypeExpr); + typeArg = CoerceToProperType(TypeExp(elementTypeExpr)); + } + // If we failed to get a valid type (either because + // there was no matching argument, or because the + // "just trying" coercion failed), then we create + // an error type to stand in for the argument + // + if (!typeArg.type) + { + typeArg.type = m_astBuilder->getErrorType(); + success = false; + } + coercedProperTypes.add(typeArg.type); + } + val = m_astBuilder->getTypePack(coercedProperTypes.getArrayView().arrayView); + } + else if (auto expandExpr = as(matchedArg.argExpr)) + { + auto argType = expandExpr->type.type; + if (auto typeType = as(argType)) + argType = typeType->getType(); + val = argType; + } + else if (auto typeType = as(matchedArg.argType)) + { + if (isAbstractTypePack(typeType->getType())) + { + val = typeType->getType(); + } + } + } + if (val == nullptr) + { + maybeReportGeneralError(); + return false; + } + checkedArgs.add(val); + } else { continue; @@ -497,37 +612,104 @@ namespace Slang break; } - // Note(tfoley): We might have fewer arguments than parameters in the - // case where one or more parameters had defaults. - SLANG_RELEASE_ASSERT(argCount <= paramTypes.getCount()); + Index paramIndex = 0; + Index argIndex = 0; + struct Arg { Expr* argExpr; Type* type; }; + auto readArg = [&]() -> Arg + { + if (argIndex >= argCount) + return { nullptr, nullptr }; + auto arg = context.getArg(argIndex); + Arg result = { arg, context.getArgType(argIndex) }; + argIndex++; + return result; + }; - for (Index ii = 0; ii < argCount; ++ii) + auto coerceArgToParam = [&](Arg arg, QualType paramType) -> Arg + { + auto argType = QualType(arg.type, paramType.isLeftValue); + if (!paramType) + return { nullptr, nullptr }; + if (!argType) + return { nullptr, nullptr }; + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (context.disallowNestedConversions) + { + // We need an exact match in this case. + if (!paramType->equals(argType)) + return { nullptr, nullptr }; + } + else if (!canCoerce(paramType, argType, arg.argExpr, &cost)) + { + return { nullptr, nullptr }; + } + candidate.conversionCostSum += cost; + } + else + { + arg.argExpr = coerce(CoercionSite::Argument, paramType, arg.argExpr); + } + return arg; + }; + ShortList resultArgs; + + while (paramIndex < paramTypes.getCount()) { - auto& arg = context.getArg(ii); - auto paramType = paramTypes[ii]; - auto argType = QualType(context.getArgType(ii), paramType.isLeftValue); - if (!paramType) - return false; - if (!argType) - return false; - if (context.mode == OverloadResolveContext::Mode::JustTrying) + auto paramType = paramTypes[paramIndex]; + if (auto paramTypePack = as(paramType)) { - ConversionCost cost = kConversionCost_None; - if( context.disallowNestedConversions ) + ShortList innerArgs; + for (Index i = 0; i < paramTypePack->getTypeCount(); i++) { - // We need an exact match in this case. - if(!paramType->equals(argType)) + auto arg = readArg(); + auto coercedArg = coerceArgToParam(arg, QualType(paramTypePack->getElementType(i), paramType.isLeftValue)); + if (!coercedArg.type) + { return false; + } + if (context.mode == OverloadResolveContext::Mode::ForReal) + innerArgs.add(coercedArg.argExpr); } - else if (!canCoerce(paramType, argType, arg, &cost)) + if (context.mode == OverloadResolveContext::Mode::ForReal) { - return false; + auto packArg = m_astBuilder->create(); + for (auto aa : innerArgs) + packArg->args.add(aa); + packArg->type = paramType; + resultArgs.add(packArg); } - candidate.conversionCostSum += cost; } else { - arg = coerce(CoercionSite::Argument, paramType, arg); + auto arg = readArg(); + if (!arg.type) + { + // If we run out of arguments, we can exit the loop now. + // Note that in this type we don't need to worry about + // default arguments, because we already checked that + // the number of arguments was correct in `TryCheckOverloadCandidateArity`. + break; + } + auto coercedArg = coerceArgToParam(arg, paramType); + if (!coercedArg.type) + { + return false; + } + if (context.mode == OverloadResolveContext::Mode::ForReal) + resultArgs.add(coercedArg.argExpr); + } + paramIndex++; + } + if (context.mode == OverloadResolveContext::Mode::ForReal) + { + context.argCount = resultArgs.getCount(); + if (context.args) + { + context.args->setCount(context.argCount); + for (Index i = 0; i < context.argCount; i++) + (*context.args)[i] = resultArgs[i]; } } return true; @@ -1448,6 +1630,110 @@ namespace Slang AddOverloadCandidate(context, candidate, baseCost); } + bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams( + SemanticsVisitor* semantics, + const List& params, + bool computeTypes, + ShortList& outMatchedArgs) + { + // We allow params to end with one or more variadic packs. + // We will first find out how many type packs there are. + Index typePackCount = 0; + for (Index i = params.getCount() - 1; i >= 0; --i) + { + if (isTypePack(params[i].type)) + typePackCount++; + else + break; + } + auto fixedParamCount = params.getCount() - typePackCount; + + auto remainingArgCount = getArgCount() - fixedParamCount; + + // If there are remaining arguments after matching all fixed parameters, + // we'd better have at least one type pack. + if (remainingArgCount > 0 && typePackCount == 0) + return false; + + // Now we can match the arguments to the parameters. + + // The fixed part comes first. + for (Index i = 0; i < Math::Min(getArgCount(), fixedParamCount); ++i) + { + MatchedArg arg; + arg.argExpr = getArg(i); + arg.argType = getArgType(i); + outMatchedArgs.add(arg); + } + + // Try to match the variadic part. + // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param. + auto astBuilder = semantics->getASTBuilder(); + while (remainingArgCount > 0) + { + auto argType = getArgType(fixedParamCount); + if (auto typeType = as(argType)) + { + argType = typeType->getType(); + } + if (isAbstractTypePack(argType)) + { + MatchedArg arg; + arg.argExpr = getArg(fixedParamCount); + arg.argType = getArgType(fixedParamCount); + outMatchedArgs.add(arg); + fixedParamCount++; + remainingArgCount--; + typePackCount--; + continue; + } + break; + } + + if (remainingArgCount <= 0) + return true; + if (typePackCount == 0) + return false; + + // If the number of type packs can't evenly divide the remaining arguments, + // there isn't a match. + if (remainingArgCount % typePackCount != 0) + return false; + + // The default case is to group the remaining arguments into evenly divided PackExprs. + Index typePackSize = remainingArgCount / typePackCount; + for (Index i = 0; i < typePackCount; ++i) + { + PackExpr* packExpr = nullptr; + if (mode == Mode::ForReal) + { + packExpr = astBuilder->create(); + packExpr->loc = loc; + } + ShortList types; + for (Index j = 0; j < typePackSize; ++j) + { + if (packExpr) + { + auto arg = getArg(fixedParamCount + i * typePackSize + j); + packExpr->args.add(arg); + } + if (computeTypes) + types.add(getArgTypeForInference(fixedParamCount + i * typePackSize + j, semantics)); + } + MatchedArg matchedArg; + matchedArg.argExpr = packExpr; + if (computeTypes) + { + matchedArg.argType = astBuilder->getTypePack(types.getArrayView().arrayView); + if (packExpr) + packExpr->type = matchedArg.argType; + } + outMatchedArgs.add(matchedArg); + } + return true; + } + DeclRef SemanticsVisitor::inferGenericArguments( DeclRef genericDeclRef, OverloadResolveContext& context, @@ -1506,25 +1792,23 @@ namespace Slang innerParameterTypes = ¶mTypes; } - Index valueArgCount = context.getArgCount(); - Index valueParamCount = innerParameterTypes->getCount(); + ShortList matchedArgs; - // If there are too many arguments, we cannot possibly have a match. + // We now try to match arguments to parameters. // // Note that if there are *too few* arguments, we might still have // a match, because the other arguments might have default values // that can be used. // - if (valueArgCount > valueParamCount) + if (!context.matchArgumentsToParams(this, *innerParameterTypes, true, matchedArgs)) { return DeclRef(); } - // If any of the arguments were specified explicitly (and are thus known), - // we do not want to take them into account during the unification and - // constraint generation step. + // Perform type unification between arguments and parameters, so + // we can populate the resolve system with inital constraints. // - for (Index aa = 0; aa < valueArgCount; ++aa) + for (Index aa = 0; aa < matchedArgs.getCount(); ++aa) { // The question here is whether failure to "unify" an argument // and parameter should lead to immediate failure. @@ -1543,12 +1827,19 @@ namespace Slang // // So the question is then whether a mismatch during the // unification step should be taken as an immediate failure... - auto argType = context.getArgTypeForInference(aa, this); + auto argType = matchedArgs[aa].argType; auto paramType = (*innerParameterTypes)[aa]; - TryUnifyTypes( + auto canUnify = TryUnifyTypes( constraints, + ValUnificationContext(), QualType(argType, paramType.isLeftValue), paramType); + + // It is an error if we can't unify the argument with a type pack parameter. + if (!canUnify && isTypePack(paramType)) + { + return DeclRef(); + } } } else @@ -1984,7 +2275,7 @@ namespace Slang context.originalExpr = expr; context.funcLoc = funcExpr->loc; context.argCount = expr->arguments.getCount(); - context.args = expr->arguments.getBuffer(); + context.args = &expr->arguments; context.loc = expr->loc; context.sourceScope = m_outerScope; context.baseExpr = GetBaseExpr(funcExpr); @@ -2238,7 +2529,7 @@ namespace Slang context.originalExpr = genericAppExpr; context.funcLoc = baseExpr->loc; context.argCount = args.getCount(); - context.args = args.getBuffer(); + context.args = &args; context.loc = genericAppExpr->loc; context.sourceScope = m_outerScope; context.baseExpr = GetBaseExpr(baseExpr); -- cgit v1.2.3