summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-shader.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-shader.cpp')
-rw-r--r--source/slang/slang-check-shader.cpp174
1 files changed, 100 insertions, 74 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index dc4f84920..45deea109 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -238,6 +238,25 @@ DeclRef<FuncDecl> findFunctionDeclByName(Module* translationUnit, Name* name, Di
{
entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
+ if (!entryPointFuncDeclRef)
+ {
+ if (auto genDeclRef = as<GenericDecl>(declRefExpr->declRef))
+ {
+ SharedSemanticsContext context(
+ translationUnit->getLinkage(),
+ translationUnit,
+ sink);
+ SemanticsVisitor visitor(&context);
+ entryPointFuncDeclRef = createDefaultSubstitutionsIfNeeded(
+ translationUnit->getASTBuilder(),
+ &visitor,
+ translationUnit->getASTBuilder()->getMemberDeclRef(
+ genDeclRef,
+ genDeclRef.getDecl()->inner))
+ .as<FuncDecl>();
+ }
+ }
+
if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit)
entryPointFuncDeclRef = DeclRef<FuncDecl>();
}
@@ -1251,9 +1270,20 @@ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
- SLANG_ASSERT(argCount == getSpecializationParamCount());
+ if (argCount < getSpecializationParamCount())
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ getSpecializationParamCount(),
+ argCount);
+ return nullptr;
+ }
+ outConsumedArgCount = getSpecializationParamCount();
+ argCount = outConsumedArgCount;
SharedSemanticsContext semanticsContext(getLinkage(), this, sink);
SemanticsVisitor visitor(&semanticsContext);
@@ -1468,6 +1498,7 @@ static void _extractSpecializationArgs(
RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl(
SpecializationArg const* inArgs,
Index inArgCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
auto args = inArgs;
@@ -1476,15 +1507,16 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink);
SemanticsVisitor visitor(&sharedSemanticsContext);
- // The first N arguments will be for the explicit generic parameters
+ // The last N arguments will be for the implicit existential arguments
// of the entry point (if it has any).
//
+ auto existentialSpecializationParamCount = getExistentialSpecializationParamCount();
auto genericSpecializationParamCount = getGenericSpecializationParamCount();
- SLANG_ASSERT(argCount >= genericSpecializationParamCount);
RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo();
DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef;
+ Index genericArgCount = genericSpecializationParamCount;
if (genericSpecializationParamCount)
{
// We need to construct a generic application and use
@@ -1494,83 +1526,69 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters
- List<Val*> genericArgs;
+ bool isVariadic =
+ (genericDeclRef.getDecl()->getMembersOfType<GenericTypePackParamDecl>().getCount() !=
+ 0);
+
+ // If function is variadic generic, it will consume all the provided arguments.
+ if (isVariadic)
+ genericArgCount = argCount - existentialSpecializationParamCount;
- for (Index ii = 0; ii < genericSpecializationParamCount; ++ii)
+ if (genericArgCount < 0)
{
- auto specializationArg = args[ii];
- genericArgs.add(specializationArg.val);
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ genericSpecializationParamCount + existentialSpecializationParamCount,
+ argCount);
+ return nullptr;
}
- auto astBuilder = getLinkage()->getASTBuilder();
- for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>(
- getLinkage()->getASTBuilder(),
- DeclRef<ContainerDecl>(genericDeclRef)))
- {
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef =
- astBuilder->getDirectDeclRef(constraintDecl.getDecl());
- int argIndex = -1;
- int ii = 0;
-
- // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to.
- auto genericParamType = getSub(astBuilder, constraintDeclRef);
- auto genParamDeclRefType = as<DeclRefType>(genericParamType);
- if (!genParamDeclRefType)
- {
- continue;
- }
- auto genParamDeclRef = genParamDeclRefType->getDeclRef();
- // Find the generic argument index of the corresponding generic parameter type in the
- // generic parameter set.
- //
- for (auto member : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
- {
- if (member == genParamDeclRef.getDecl())
- {
- argIndex = ii;
- break;
- }
- ii++;
- }
- if (argIndex == -1)
- {
- SLANG_ASSERT(!"generic parameter not found in generic decl");
- continue;
- }
- auto sub = as<Type>(args[argIndex].val);
- if (!sub)
- {
- sink->diagnose(
- constraintDecl,
- Diagnostics::expectedTypeForSpecializationArg,
- argIndex);
- continue;
- }
+ List<Expr*> genericArgs;
- auto sup = getSup(astBuilder, constraintDeclRef);
- auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None);
- if (subTypeWitness)
- {
- genericArgs.add(subTypeWitness);
- }
- else
+ auto astBuilder = getLinkage()->getASTBuilder();
+ for (Index ii = 0; ii < genericArgCount; ++ii)
+ {
+ auto specializationArg = args[ii];
+ if (specializationArg.expr)
{
- // TODO: diagnose a problem here
- sink->diagnose(
- constraintDecl,
- Diagnostics::typeArgumentDoesNotConformToInterface,
- sub,
- sup);
+ genericArgs.add(specializationArg.expr);
continue;
}
+ auto typeExpr = astBuilder->create<SharedTypeExpr>();
+ typeExpr->type = astBuilder->getTypeType((Type*)specializationArg.val);
+ genericArgs.add(typeExpr);
+ }
+ auto genAppExpr = astBuilder->create<GenericAppExpr>();
+ auto genExpr = astBuilder->create<VarExpr>();
+ genExpr->declRef = genericDeclRef;
+ genExpr->type = astBuilder->getOrCreate<GenericDeclRefType>();
+ genExpr->checked = true;
+ genAppExpr->functionExpr = genExpr;
+ genAppExpr->arguments = _Move(genericArgs);
+ auto checkedExpr = visitor.CheckTerm(genAppExpr);
+ if (auto partiallyAppliedExpr = as<PartiallyAppliedGenericExpr>(checkedExpr))
+ {
+ // If checked generic is partially applied generic, we try to force conversion into
+ // a fully defined declref by calling `trySolveConstraintSystem`.
+ SemanticsVisitor::ConstraintSystem system;
+ system.genericDecl = genericDeclRef.getDecl();
+ ConversionCost outCost;
+ specializedFuncDeclRef = visitor
+ .trySolveConstraintSystem(
+ &system,
+ genericDeclRef,
+ partiallyAppliedExpr->knownGenericArgs.getArrayView(),
+ outCost)
+ .as<FuncDecl>();
+ }
+ else if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
+ {
+ specializedFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
}
- specializedFuncDeclRef =
- getLinkage()
- ->getASTBuilder()
- ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView())
- .as<FuncDecl>();
- SLANG_ASSERT(specializedFuncDeclRef);
+ if (!specializedFuncDeclRef)
+ return nullptr;
}
info->specializedFuncDeclRef = specializedFuncDeclRef;
@@ -1580,11 +1598,19 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
// specialization parameters, attached to the value parameters
// of the entry point.
//
- args += genericSpecializationParamCount;
- argCount -= genericSpecializationParamCount;
+ args += genericArgCount;
+ argCount -= genericArgCount;
+ outConsumedArgCount = genericArgCount + existentialSpecializationParamCount;
- auto existentialSpecializationParamCount = getExistentialSpecializationParamCount();
- SLANG_ASSERT(argCount == existentialSpecializationParamCount);
+ if (argCount < existentialSpecializationParamCount)
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ genericSpecializationParamCount + existentialSpecializationParamCount,
+ argCount);
+ return nullptr;
+ }
for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii)
{