summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-overload.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-14 18:41:48 -0700
committerGitHub <noreply@github.com>2024-08-14 18:41:48 -0700
commit071f1b6062b459928ebfd6f2f60a8d6ad021112b (patch)
tree2ba65eb40f39701db6fc775a9258ec8079d161a0 /source/slang/slang-check-overload.cpp
parent35a3d32c87f079749f6b100d01b289c3da02d7d6 (diff)
Variadic Generics Part 1: parsing and type checking. (#4833)
Diffstat (limited to 'source/slang/slang-check-overload.cpp')
-rw-r--r--source/slang/slang-check-overload.cpp401
1 files changed, 346 insertions, 55 deletions
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index d37c6e469..16f6ed7da 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -15,22 +15,37 @@ namespace Slang
ParamCounts counts = { 0, 0 };
for (auto param : params)
{
- counts.allowed++;
-
- // No initializer means no default value
- //
- // TODO(tfoley): The logic here is currently broken in two ways:
- //
- // 1. We are assuming that once one parameter has a default, then all do.
- // This can/should be validated earlier, so that we can assume it here.
- //
- // 2. We are not handling the possibility of multiple declarations for
- // a single function, where we'd need to merge default parameters across
- // all the declarations.
- if (!param.getDecl()->initExpr)
+ Index allowedArgCountToAdd = 1;
+ auto paramType = getParamType(m_astBuilder, param);
+ if (isTypePack(paramType))
{
+ if (auto typePack = as<ConcreteTypePack>(paramType))
+ {
+ counts.required += typePack->getTypeCount();
+ allowedArgCountToAdd = typePack->getTypeCount();
+ }
+ else
+ {
+ counts.allowed = -1;
+ }
+ }
+ else if (!param.getDecl()->initExpr)
+ {
+ // No initializer means no default value
+ //
+ // TODO(tfoley): The logic here is currently broken in two ways:
+ //
+ // 1. We are assuming that once one parameter has a default, then all do.
+ // This can/should be validated earlier, so that we can assume it here.
+ //
+ // 2. We are not handling the possibility of multiple declarations for
+ // a single function, where we'd need to merge default parameters across
+ // all the declarations.
counts.required++;
}
+
+ if (counts.allowed >= 0)
+ counts.allowed += allowedArgCountToAdd;
}
return counts;
}
@@ -42,7 +57,8 @@ namespace Slang
{
if (auto typeParam = as<GenericTypeParamDecl>(m))
{
- counts.allowed++;
+ if (counts.allowed >= 0)
+ counts.allowed++;
if (!typeParam->initType.Ptr())
{
counts.required++;
@@ -50,12 +66,17 @@ namespace Slang
}
else if (auto valParam = as<GenericValueParamDecl>(m))
{
- counts.allowed++;
+ if (counts.allowed >= 0)
+ counts.allowed++;
if (!valParam->initExpr)
{
counts.required++;
}
}
+ else if (as<GenericTypePackParamDecl>(m))
+ {
+ counts.allowed = -1;
+ }
}
return counts;
}
@@ -130,7 +151,7 @@ namespace Slang
break;
}
- if (argCount >= paramCounts.required && argCount <= paramCounts.allowed)
+ if (argCount >= paramCounts.required && (paramCounts.allowed == -1 || argCount <= paramCounts.allowed))
return true;
// Emit an error message if we are checking this call for real
@@ -270,13 +291,35 @@ namespace Slang
getSink()->diagnose(context.loc, Diagnostics::cannotSpecializeGeneric, candidate.item.declRef);
}
};
+ List<QualType> paramTypes;
+ for (auto memberRef : getMembers(m_astBuilder, genericDeclRef))
+ {
+ if (auto typeParamRef = memberRef.as<GenericTypeParamDecl>())
+ {
+ paramTypes.add(DeclRefType::create(m_astBuilder, typeParamRef));
+ }
+ else if (auto valParamRef = memberRef.as<GenericValueParamDecl>())
+ {
+ paramTypes.add(getType(m_astBuilder, valParamRef));
+ }
+ else if (auto typePackParam = memberRef.as<GenericTypePackParamDecl>())
+ {
+ paramTypes.add(DeclRefType::create(m_astBuilder, typePackParam));
+ }
+ }
+ ShortList<OverloadResolveContext::MatchedArg> matchedArgs;
+ if (!context.matchArgumentsToParams(this, paramTypes, false, matchedArgs))
+ {
+ maybeReportGeneralError();
+ return false;
+ }
Index aa = 0;
for (auto memberRef : getMembers(m_astBuilder, genericDeclRef))
{
if (auto typeParamRef = memberRef.as<GenericTypeParamDecl>())
{
- if (aa >= context.argCount)
+ if (aa >= matchedArgs.getCount())
{
if (allowPartialGenericApp)
{
@@ -319,15 +362,15 @@ namespace Slang
// the checking "for real" in which case any errors
// we run into need to be reported.
//
- auto arg = context.getArg(aa++);
+ auto arg = matchedArgs[aa++];
if (context.mode == OverloadResolveContext::Mode::JustTrying)
{
- typeArg = tryCoerceToProperType(TypeExp(arg));
+ typeArg = tryCoerceToProperType(TypeExp(arg.argExpr));
}
else
{
- arg = ExpectATypeRepr(arg);
- typeArg = CoerceToProperType(TypeExp(arg));
+ arg.argExpr = ExpectATypeRepr(arg.argExpr);
+ typeArg = CoerceToProperType(TypeExp(arg.argExpr));
}
// If we failed to get a valid type (either because
@@ -345,7 +388,7 @@ namespace Slang
}
else if (auto valParamRef = memberRef.as<GenericValueParamDecl>())
{
- if (aa >= context.argCount)
+ if (aa >= matchedArgs.getCount())
{
if (allowPartialGenericApp)
{
@@ -384,7 +427,7 @@ namespace Slang
// to the type of the parameter (and fail if the
// coercion is not possible)
//
- arg = context.getArg(aa++);
+ arg = matchedArgs[aa++].argExpr;
if (context.mode == OverloadResolveContext::Mode::JustTrying)
{
ConversionCost cost = kConversionCost_None;
@@ -418,6 +461,78 @@ namespace Slang
}
checkedArgs.add(val);
}
+ else if (auto typePackParam = memberRef.as<GenericTypePackParamDecl>())
+ {
+ Val* val = nullptr;
+ if (aa >= matchedArgs.getCount())
+ {
+ // If we run out of matched args, we will just create an empty pack.
+ val = m_astBuilder->getTypePack(ArrayView<Type*>());
+ }
+ else
+ {
+ auto matchedArg = matchedArgs[aa++];
+ if (auto packExpr = as<PackExpr>(matchedArg.argExpr))
+ {
+ // We are providing a concrete pack of types as arguments to a type pack parameter.
+ // We need to create a `TypePack` type to serve as the argument.
+ ShortList<Type*> coercedProperTypes;
+
+ // Coerce all types in the pack to proper types.
+ for (Index i = 0; i < packExpr->args.getCount(); i++)
+ {
+ TypeExp typeArg;
+ auto elementTypeExpr = packExpr->args[i];
+ if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ {
+ typeArg = tryCoerceToProperType(TypeExp(elementTypeExpr));
+ if (!typeArg.type)
+ {
+ typeArg.type = m_astBuilder->getErrorType();
+ success = false;
+ }
+ }
+ else
+ {
+ elementTypeExpr = ExpectATypeRepr(elementTypeExpr);
+ typeArg = CoerceToProperType(TypeExp(elementTypeExpr));
+ }
+ // If we failed to get a valid type (either because
+ // there was no matching argument, or because the
+ // "just trying" coercion failed), then we create
+ // an error type to stand in for the argument
+ //
+ if (!typeArg.type)
+ {
+ typeArg.type = m_astBuilder->getErrorType();
+ success = false;
+ }
+ coercedProperTypes.add(typeArg.type);
+ }
+ val = m_astBuilder->getTypePack(coercedProperTypes.getArrayView().arrayView);
+ }
+ else if (auto expandExpr = as<ExpandExpr>(matchedArg.argExpr))
+ {
+ auto argType = expandExpr->type.type;
+ if (auto typeType = as<TypeType>(argType))
+ argType = typeType->getType();
+ val = argType;
+ }
+ else if (auto typeType = as<TypeType>(matchedArg.argType))
+ {
+ if (isAbstractTypePack(typeType->getType()))
+ {
+ val = typeType->getType();
+ }
+ }
+ }
+ if (val == nullptr)
+ {
+ maybeReportGeneralError();
+ return false;
+ }
+ checkedArgs.add(val);
+ }
else
{
continue;
@@ -497,37 +612,104 @@ namespace Slang
break;
}
- // Note(tfoley): We might have fewer arguments than parameters in the
- // case where one or more parameters had defaults.
- SLANG_RELEASE_ASSERT(argCount <= paramTypes.getCount());
+ Index paramIndex = 0;
+ Index argIndex = 0;
+ struct Arg { Expr* argExpr; Type* type; };
+ auto readArg = [&]() -> Arg
+ {
+ if (argIndex >= argCount)
+ return { nullptr, nullptr };
+ auto arg = context.getArg(argIndex);
+ Arg result = { arg, context.getArgType(argIndex) };
+ argIndex++;
+ return result;
+ };
- for (Index ii = 0; ii < argCount; ++ii)
+ auto coerceArgToParam = [&](Arg arg, QualType paramType) -> Arg
+ {
+ auto argType = QualType(arg.type, paramType.isLeftValue);
+ if (!paramType)
+ return { nullptr, nullptr };
+ if (!argType)
+ return { nullptr, nullptr };
+ if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ {
+ ConversionCost cost = kConversionCost_None;
+ if (context.disallowNestedConversions)
+ {
+ // We need an exact match in this case.
+ if (!paramType->equals(argType))
+ return { nullptr, nullptr };
+ }
+ else if (!canCoerce(paramType, argType, arg.argExpr, &cost))
+ {
+ return { nullptr, nullptr };
+ }
+ candidate.conversionCostSum += cost;
+ }
+ else
+ {
+ arg.argExpr = coerce(CoercionSite::Argument, paramType, arg.argExpr);
+ }
+ return arg;
+ };
+ ShortList<Expr*> resultArgs;
+
+ while (paramIndex < paramTypes.getCount())
{
- auto& arg = context.getArg(ii);
- auto paramType = paramTypes[ii];
- auto argType = QualType(context.getArgType(ii), paramType.isLeftValue);
- if (!paramType)
- return false;
- if (!argType)
- return false;
- if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ auto paramType = paramTypes[paramIndex];
+ if (auto paramTypePack = as<ConcreteTypePack>(paramType))
{
- ConversionCost cost = kConversionCost_None;
- if( context.disallowNestedConversions )
+ ShortList<Expr*> innerArgs;
+ for (Index i = 0; i < paramTypePack->getTypeCount(); i++)
{
- // We need an exact match in this case.
- if(!paramType->equals(argType))
+ auto arg = readArg();
+ auto coercedArg = coerceArgToParam(arg, QualType(paramTypePack->getElementType(i), paramType.isLeftValue));
+ if (!coercedArg.type)
+ {
return false;
+ }
+ if (context.mode == OverloadResolveContext::Mode::ForReal)
+ innerArgs.add(coercedArg.argExpr);
}
- else if (!canCoerce(paramType, argType, arg, &cost))
+ if (context.mode == OverloadResolveContext::Mode::ForReal)
{
- return false;
+ auto packArg = m_astBuilder->create<PackExpr>();
+ for (auto aa : innerArgs)
+ packArg->args.add(aa);
+ packArg->type = paramType;
+ resultArgs.add(packArg);
}
- candidate.conversionCostSum += cost;
}
else
{
- arg = coerce(CoercionSite::Argument, paramType, arg);
+ auto arg = readArg();
+ if (!arg.type)
+ {
+ // If we run out of arguments, we can exit the loop now.
+ // Note that in this type we don't need to worry about
+ // default arguments, because we already checked that
+ // the number of arguments was correct in `TryCheckOverloadCandidateArity`.
+ break;
+ }
+ auto coercedArg = coerceArgToParam(arg, paramType);
+ if (!coercedArg.type)
+ {
+ return false;
+ }
+ if (context.mode == OverloadResolveContext::Mode::ForReal)
+ resultArgs.add(coercedArg.argExpr);
+ }
+ paramIndex++;
+ }
+ if (context.mode == OverloadResolveContext::Mode::ForReal)
+ {
+ context.argCount = resultArgs.getCount();
+ if (context.args)
+ {
+ context.args->setCount(context.argCount);
+ for (Index i = 0; i < context.argCount; i++)
+ (*context.args)[i] = resultArgs[i];
}
}
return true;
@@ -1448,6 +1630,110 @@ namespace Slang
AddOverloadCandidate(context, candidate, baseCost);
}
+ bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams(
+ SemanticsVisitor* semantics,
+ const List<QualType>& params,
+ bool computeTypes,
+ ShortList<MatchedArg>& outMatchedArgs)
+ {
+ // We allow params to end with one or more variadic packs.
+ // We will first find out how many type packs there are.
+ Index typePackCount = 0;
+ for (Index i = params.getCount() - 1; i >= 0; --i)
+ {
+ if (isTypePack(params[i].type))
+ typePackCount++;
+ else
+ break;
+ }
+ auto fixedParamCount = params.getCount() - typePackCount;
+
+ auto remainingArgCount = getArgCount() - fixedParamCount;
+
+ // If there are remaining arguments after matching all fixed parameters,
+ // we'd better have at least one type pack.
+ if (remainingArgCount > 0 && typePackCount == 0)
+ return false;
+
+ // Now we can match the arguments to the parameters.
+
+ // The fixed part comes first.
+ for (Index i = 0; i < Math::Min(getArgCount(), fixedParamCount); ++i)
+ {
+ MatchedArg arg;
+ arg.argExpr = getArg(i);
+ arg.argType = getArgType(i);
+ outMatchedArgs.add(arg);
+ }
+
+ // Try to match the variadic part.
+ // Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param.
+ auto astBuilder = semantics->getASTBuilder();
+ while (remainingArgCount > 0)
+ {
+ auto argType = getArgType(fixedParamCount);
+ if (auto typeType = as<TypeType>(argType))
+ {
+ argType = typeType->getType();
+ }
+ if (isAbstractTypePack(argType))
+ {
+ MatchedArg arg;
+ arg.argExpr = getArg(fixedParamCount);
+ arg.argType = getArgType(fixedParamCount);
+ outMatchedArgs.add(arg);
+ fixedParamCount++;
+ remainingArgCount--;
+ typePackCount--;
+ continue;
+ }
+ break;
+ }
+
+ if (remainingArgCount <= 0)
+ return true;
+ if (typePackCount == 0)
+ return false;
+
+ // If the number of type packs can't evenly divide the remaining arguments,
+ // there isn't a match.
+ if (remainingArgCount % typePackCount != 0)
+ return false;
+
+ // The default case is to group the remaining arguments into evenly divided PackExprs.
+ Index typePackSize = remainingArgCount / typePackCount;
+ for (Index i = 0; i < typePackCount; ++i)
+ {
+ PackExpr* packExpr = nullptr;
+ if (mode == Mode::ForReal)
+ {
+ packExpr = astBuilder->create<PackExpr>();
+ packExpr->loc = loc;
+ }
+ ShortList<Type*> types;
+ for (Index j = 0; j < typePackSize; ++j)
+ {
+ if (packExpr)
+ {
+ auto arg = getArg(fixedParamCount + i * typePackSize + j);
+ packExpr->args.add(arg);
+ }
+ if (computeTypes)
+ types.add(getArgTypeForInference(fixedParamCount + i * typePackSize + j, semantics));
+ }
+ MatchedArg matchedArg;
+ matchedArg.argExpr = packExpr;
+ if (computeTypes)
+ {
+ matchedArg.argType = astBuilder->getTypePack(types.getArrayView().arrayView);
+ if (packExpr)
+ packExpr->type = matchedArg.argType;
+ }
+ outMatchedArgs.add(matchedArg);
+ }
+ return true;
+ }
+
DeclRef<Decl> SemanticsVisitor::inferGenericArguments(
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
@@ -1506,25 +1792,23 @@ namespace Slang
innerParameterTypes = &paramTypes;
}
- Index valueArgCount = context.getArgCount();
- Index valueParamCount = innerParameterTypes->getCount();
+ ShortList<OverloadResolveContext::MatchedArg> matchedArgs;
- // If there are too many arguments, we cannot possibly have a match.
+ // We now try to match arguments to parameters.
//
// Note that if there are *too few* arguments, we might still have
// a match, because the other arguments might have default values
// that can be used.
//
- if (valueArgCount > valueParamCount)
+ if (!context.matchArgumentsToParams(this, *innerParameterTypes, true, matchedArgs))
{
return DeclRef<Decl>();
}
- // If any of the arguments were specified explicitly (and are thus known),
- // we do not want to take them into account during the unification and
- // constraint generation step.
+ // Perform type unification between arguments and parameters, so
+ // we can populate the resolve system with inital constraints.
//
- for (Index aa = 0; aa < valueArgCount; ++aa)
+ for (Index aa = 0; aa < matchedArgs.getCount(); ++aa)
{
// The question here is whether failure to "unify" an argument
// and parameter should lead to immediate failure.
@@ -1543,12 +1827,19 @@ namespace Slang
//
// So the question is then whether a mismatch during the
// unification step should be taken as an immediate failure...
- auto argType = context.getArgTypeForInference(aa, this);
+ auto argType = matchedArgs[aa].argType;
auto paramType = (*innerParameterTypes)[aa];
- TryUnifyTypes(
+ auto canUnify = TryUnifyTypes(
constraints,
+ ValUnificationContext(),
QualType(argType, paramType.isLeftValue),
paramType);
+
+ // It is an error if we can't unify the argument with a type pack parameter.
+ if (!canUnify && isTypePack(paramType))
+ {
+ return DeclRef<Decl>();
+ }
}
}
else
@@ -1984,7 +2275,7 @@ namespace Slang
context.originalExpr = expr;
context.funcLoc = funcExpr->loc;
context.argCount = expr->arguments.getCount();
- context.args = expr->arguments.getBuffer();
+ context.args = &expr->arguments;
context.loc = expr->loc;
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(funcExpr);
@@ -2238,7 +2529,7 @@ namespace Slang
context.originalExpr = genericAppExpr;
context.funcLoc = baseExpr->loc;
context.argCount = args.getCount();
- context.args = args.getBuffer();
+ context.args = &args;
context.loc = genericAppExpr->loc;
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(baseExpr);