From dcdebc1a76a0a6ffbfd6a5805354f8f679c60202 Mon Sep 17 00:00:00 2001 From: Yong He Date: Sat, 9 Aug 2025 09:43:25 -0700 Subject: Allow specializing entrypoints with generic value args or variadic types from API (#8119) Closes #8110. Closes #8011. --- source/slang/slang-check-shader.cpp | 174 +++++++++++++++++++++--------------- 1 file changed, 100 insertions(+), 74 deletions(-) (limited to 'source/slang/slang-check-shader.cpp') diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index dc4f84920..45deea109 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -238,6 +238,25 @@ DeclRef findFunctionDeclByName(Module* translationUnit, Name* name, Di { entryPointFuncDeclRef = declRefExpr->declRef.as(); + if (!entryPointFuncDeclRef) + { + if (auto genDeclRef = as(declRefExpr->declRef)) + { + SharedSemanticsContext context( + translationUnit->getLinkage(), + translationUnit, + sink); + SemanticsVisitor visitor(&context); + entryPointFuncDeclRef = createDefaultSubstitutionsIfNeeded( + translationUnit->getASTBuilder(), + &visitor, + translationUnit->getASTBuilder()->getMemberDeclRef( + genDeclRef, + genDeclRef.getDecl()->inner)) + .as(); + } + } + if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit) entryPointFuncDeclRef = DeclRef(); } @@ -1251,9 +1270,20 @@ RefPtr createUnspecializedGlobalAndEntryPointsComponentType( RefPtr Module::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { - SLANG_ASSERT(argCount == getSpecializationParamCount()); + if (argCount < getSpecializationParamCount()) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + getSpecializationParamCount(), + argCount); + return nullptr; + } + outConsumedArgCount = getSpecializationParamCount(); + argCount = outConsumedArgCount; SharedSemanticsContext semanticsContext(getLinkage(), this, sink); SemanticsVisitor visitor(&semanticsContext); @@ -1468,6 +1498,7 @@ static void _extractSpecializationArgs( RefPtr EntryPoint::_validateSpecializationArgsImpl( SpecializationArg const* inArgs, Index inArgCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { auto args = inArgs; @@ -1476,15 +1507,16 @@ RefPtr EntryPoint::_validateSpecializationArg SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); - // The first N arguments will be for the explicit generic parameters + // The last N arguments will be for the implicit existential arguments // of the entry point (if it has any). // + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); auto genericSpecializationParamCount = getGenericSpecializationParamCount(); - SLANG_ASSERT(argCount >= genericSpecializationParamCount); RefPtr info = new EntryPointSpecializationInfo(); DeclRef specializedFuncDeclRef = m_funcDeclRef; + Index genericArgCount = genericSpecializationParamCount; if (genericSpecializationParamCount) { // We need to construct a generic application and use @@ -1494,83 +1526,69 @@ RefPtr EntryPoint::_validateSpecializationArg auto genericDeclRef = m_funcDeclRef.getParent().as(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - List genericArgs; + bool isVariadic = + (genericDeclRef.getDecl()->getMembersOfType().getCount() != + 0); + + // If function is variadic generic, it will consume all the provided arguments. + if (isVariadic) + genericArgCount = argCount - existentialSpecializationParamCount; - for (Index ii = 0; ii < genericSpecializationParamCount; ++ii) + if (genericArgCount < 0) { - auto specializationArg = args[ii]; - genericArgs.add(specializationArg.val); + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; } - auto astBuilder = getLinkage()->getASTBuilder(); - for (auto constraintDecl : getMembersOfType( - getLinkage()->getASTBuilder(), - DeclRef(genericDeclRef))) - { - DeclRef constraintDeclRef = - astBuilder->getDirectDeclRef(constraintDecl.getDecl()); - int argIndex = -1; - int ii = 0; - - // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to. - auto genericParamType = getSub(astBuilder, constraintDeclRef); - auto genParamDeclRefType = as(genericParamType); - if (!genParamDeclRefType) - { - continue; - } - auto genParamDeclRef = genParamDeclRefType->getDeclRef(); - // Find the generic argument index of the corresponding generic parameter type in the - // generic parameter set. - // - for (auto member : genericDeclRef.getDecl()->getMembersOfType()) - { - if (member == genParamDeclRef.getDecl()) - { - argIndex = ii; - break; - } - ii++; - } - if (argIndex == -1) - { - SLANG_ASSERT(!"generic parameter not found in generic decl"); - continue; - } - auto sub = as(args[argIndex].val); - if (!sub) - { - sink->diagnose( - constraintDecl, - Diagnostics::expectedTypeForSpecializationArg, - argIndex); - continue; - } + List genericArgs; - auto sup = getSup(astBuilder, constraintDeclRef); - auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None); - if (subTypeWitness) - { - genericArgs.add(subTypeWitness); - } - else + auto astBuilder = getLinkage()->getASTBuilder(); + for (Index ii = 0; ii < genericArgCount; ++ii) + { + auto specializationArg = args[ii]; + if (specializationArg.expr) { - // TODO: diagnose a problem here - sink->diagnose( - constraintDecl, - Diagnostics::typeArgumentDoesNotConformToInterface, - sub, - sup); + genericArgs.add(specializationArg.expr); continue; } + auto typeExpr = astBuilder->create(); + typeExpr->type = astBuilder->getTypeType((Type*)specializationArg.val); + genericArgs.add(typeExpr); + } + auto genAppExpr = astBuilder->create(); + auto genExpr = astBuilder->create(); + genExpr->declRef = genericDeclRef; + genExpr->type = astBuilder->getOrCreate(); + genExpr->checked = true; + genAppExpr->functionExpr = genExpr; + genAppExpr->arguments = _Move(genericArgs); + auto checkedExpr = visitor.CheckTerm(genAppExpr); + if (auto partiallyAppliedExpr = as(checkedExpr)) + { + // If checked generic is partially applied generic, we try to force conversion into + // a fully defined declref by calling `trySolveConstraintSystem`. + SemanticsVisitor::ConstraintSystem system; + system.genericDecl = genericDeclRef.getDecl(); + ConversionCost outCost; + specializedFuncDeclRef = visitor + .trySolveConstraintSystem( + &system, + genericDeclRef, + partiallyAppliedExpr->knownGenericArgs.getArrayView(), + outCost) + .as(); + } + else if (auto declRefExpr = as(checkedExpr)) + { + specializedFuncDeclRef = declRefExpr->declRef.as(); } - specializedFuncDeclRef = - getLinkage() - ->getASTBuilder() - ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()) - .as(); - SLANG_ASSERT(specializedFuncDeclRef); + if (!specializedFuncDeclRef) + return nullptr; } info->specializedFuncDeclRef = specializedFuncDeclRef; @@ -1580,11 +1598,19 @@ RefPtr EntryPoint::_validateSpecializationArg // specialization parameters, attached to the value parameters // of the entry point. // - args += genericSpecializationParamCount; - argCount -= genericSpecializationParamCount; + args += genericArgCount; + argCount -= genericArgCount; + outConsumedArgCount = genericArgCount + existentialSpecializationParamCount; - auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); - SLANG_ASSERT(argCount == existentialSpecializationParamCount); + if (argCount < existentialSpecializationParamCount) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; + } for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii) { -- cgit v1.2.3