summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-08-09 09:43:25 -0700
committerGitHub <noreply@github.com>2025-08-09 16:43:25 +0000
commitdcdebc1a76a0a6ffbfd6a5805354f8f679c60202 (patch)
tree126d60d157e73e401aacf1e13b400b8533ec8828 /source/slang
parentfc6aea37483446372425aca8471f0e8bf7c3a910 (diff)
Allow specializing entrypoints with generic value args or variadic types from API (#8119)
Closes #8110. Closes #8011.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-shader.cpp174
-rw-r--r--source/slang/slang-entry-point.h1
-rw-r--r--source/slang/slang-linkable-impls.cpp18
-rw-r--r--source/slang/slang-linkable-impls.h7
-rw-r--r--source/slang/slang-linkable.cpp70
-rw-r--r--source/slang/slang-linkable.h6
-rw-r--r--source/slang/slang-module.h1
8 files changed, 166 insertions, 112 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 6dc4203a6..9ad58c776 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1627,6 +1627,7 @@ FIDDLE() namespace Slang
struct SpecializationArg
{
Val* val = nullptr;
+ Expr* expr = nullptr;
};
typedef List<SpecializationArg> SpecializationArgs;
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index dc4f84920..45deea109 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -238,6 +238,25 @@ DeclRef<FuncDecl> findFunctionDeclByName(Module* translationUnit, Name* name, Di
{
entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
+ if (!entryPointFuncDeclRef)
+ {
+ if (auto genDeclRef = as<GenericDecl>(declRefExpr->declRef))
+ {
+ SharedSemanticsContext context(
+ translationUnit->getLinkage(),
+ translationUnit,
+ sink);
+ SemanticsVisitor visitor(&context);
+ entryPointFuncDeclRef = createDefaultSubstitutionsIfNeeded(
+ translationUnit->getASTBuilder(),
+ &visitor,
+ translationUnit->getASTBuilder()->getMemberDeclRef(
+ genDeclRef,
+ genDeclRef.getDecl()->inner))
+ .as<FuncDecl>();
+ }
+ }
+
if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit)
entryPointFuncDeclRef = DeclRef<FuncDecl>();
}
@@ -1251,9 +1270,20 @@ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType(
RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
- SLANG_ASSERT(argCount == getSpecializationParamCount());
+ if (argCount < getSpecializationParamCount())
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ getSpecializationParamCount(),
+ argCount);
+ return nullptr;
+ }
+ outConsumedArgCount = getSpecializationParamCount();
+ argCount = outConsumedArgCount;
SharedSemanticsContext semanticsContext(getLinkage(), this, sink);
SemanticsVisitor visitor(&semanticsContext);
@@ -1468,6 +1498,7 @@ static void _extractSpecializationArgs(
RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl(
SpecializationArg const* inArgs,
Index inArgCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
auto args = inArgs;
@@ -1476,15 +1507,16 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink);
SemanticsVisitor visitor(&sharedSemanticsContext);
- // The first N arguments will be for the explicit generic parameters
+ // The last N arguments will be for the implicit existential arguments
// of the entry point (if it has any).
//
+ auto existentialSpecializationParamCount = getExistentialSpecializationParamCount();
auto genericSpecializationParamCount = getGenericSpecializationParamCount();
- SLANG_ASSERT(argCount >= genericSpecializationParamCount);
RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo();
DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef;
+ Index genericArgCount = genericSpecializationParamCount;
if (genericSpecializationParamCount)
{
// We need to construct a generic application and use
@@ -1494,83 +1526,69 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>();
SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters
- List<Val*> genericArgs;
+ bool isVariadic =
+ (genericDeclRef.getDecl()->getMembersOfType<GenericTypePackParamDecl>().getCount() !=
+ 0);
+
+ // If function is variadic generic, it will consume all the provided arguments.
+ if (isVariadic)
+ genericArgCount = argCount - existentialSpecializationParamCount;
- for (Index ii = 0; ii < genericSpecializationParamCount; ++ii)
+ if (genericArgCount < 0)
{
- auto specializationArg = args[ii];
- genericArgs.add(specializationArg.val);
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ genericSpecializationParamCount + existentialSpecializationParamCount,
+ argCount);
+ return nullptr;
}
- auto astBuilder = getLinkage()->getASTBuilder();
- for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>(
- getLinkage()->getASTBuilder(),
- DeclRef<ContainerDecl>(genericDeclRef)))
- {
- DeclRef<GenericTypeConstraintDecl> constraintDeclRef =
- astBuilder->getDirectDeclRef(constraintDecl.getDecl());
- int argIndex = -1;
- int ii = 0;
-
- // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to.
- auto genericParamType = getSub(astBuilder, constraintDeclRef);
- auto genParamDeclRefType = as<DeclRefType>(genericParamType);
- if (!genParamDeclRefType)
- {
- continue;
- }
- auto genParamDeclRef = genParamDeclRefType->getDeclRef();
- // Find the generic argument index of the corresponding generic parameter type in the
- // generic parameter set.
- //
- for (auto member : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
- {
- if (member == genParamDeclRef.getDecl())
- {
- argIndex = ii;
- break;
- }
- ii++;
- }
- if (argIndex == -1)
- {
- SLANG_ASSERT(!"generic parameter not found in generic decl");
- continue;
- }
- auto sub = as<Type>(args[argIndex].val);
- if (!sub)
- {
- sink->diagnose(
- constraintDecl,
- Diagnostics::expectedTypeForSpecializationArg,
- argIndex);
- continue;
- }
+ List<Expr*> genericArgs;
- auto sup = getSup(astBuilder, constraintDeclRef);
- auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None);
- if (subTypeWitness)
- {
- genericArgs.add(subTypeWitness);
- }
- else
+ auto astBuilder = getLinkage()->getASTBuilder();
+ for (Index ii = 0; ii < genericArgCount; ++ii)
+ {
+ auto specializationArg = args[ii];
+ if (specializationArg.expr)
{
- // TODO: diagnose a problem here
- sink->diagnose(
- constraintDecl,
- Diagnostics::typeArgumentDoesNotConformToInterface,
- sub,
- sup);
+ genericArgs.add(specializationArg.expr);
continue;
}
+ auto typeExpr = astBuilder->create<SharedTypeExpr>();
+ typeExpr->type = astBuilder->getTypeType((Type*)specializationArg.val);
+ genericArgs.add(typeExpr);
+ }
+ auto genAppExpr = astBuilder->create<GenericAppExpr>();
+ auto genExpr = astBuilder->create<VarExpr>();
+ genExpr->declRef = genericDeclRef;
+ genExpr->type = astBuilder->getOrCreate<GenericDeclRefType>();
+ genExpr->checked = true;
+ genAppExpr->functionExpr = genExpr;
+ genAppExpr->arguments = _Move(genericArgs);
+ auto checkedExpr = visitor.CheckTerm(genAppExpr);
+ if (auto partiallyAppliedExpr = as<PartiallyAppliedGenericExpr>(checkedExpr))
+ {
+ // If checked generic is partially applied generic, we try to force conversion into
+ // a fully defined declref by calling `trySolveConstraintSystem`.
+ SemanticsVisitor::ConstraintSystem system;
+ system.genericDecl = genericDeclRef.getDecl();
+ ConversionCost outCost;
+ specializedFuncDeclRef = visitor
+ .trySolveConstraintSystem(
+ &system,
+ genericDeclRef,
+ partiallyAppliedExpr->knownGenericArgs.getArrayView(),
+ outCost)
+ .as<FuncDecl>();
+ }
+ else if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
+ {
+ specializedFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
}
- specializedFuncDeclRef =
- getLinkage()
- ->getASTBuilder()
- ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView())
- .as<FuncDecl>();
- SLANG_ASSERT(specializedFuncDeclRef);
+ if (!specializedFuncDeclRef)
+ return nullptr;
}
info->specializedFuncDeclRef = specializedFuncDeclRef;
@@ -1580,11 +1598,19 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg
// specialization parameters, attached to the value parameters
// of the entry point.
//
- args += genericSpecializationParamCount;
- argCount -= genericSpecializationParamCount;
+ args += genericArgCount;
+ argCount -= genericArgCount;
+ outConsumedArgCount = genericArgCount + existentialSpecializationParamCount;
- auto existentialSpecializationParamCount = getExistentialSpecializationParamCount();
- SLANG_ASSERT(argCount == existentialSpecializationParamCount);
+ if (argCount < existentialSpecializationParamCount)
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ genericSpecializationParamCount + existentialSpecializationParamCount,
+ argCount);
+ return nullptr;
+ }
for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii)
{
diff --git a/source/slang/slang-entry-point.h b/source/slang/slang-entry-point.h
index 16499a542..305e2c77e 100644
--- a/source/slang/slang-entry-point.h
+++ b/source/slang/slang-entry-point.h
@@ -292,6 +292,7 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE;
private:
diff --git a/source/slang/slang-linkable-impls.cpp b/source/slang/slang-linkable-impls.cpp
index d03ecb3ca..082974fae 100644
--- a/source/slang/slang-linkable-impls.cpp
+++ b/source/slang/slang-linkable-impls.cpp
@@ -180,6 +180,7 @@ void CompositeComponentType::acceptVisitor(
RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
SLANG_UNUSED(argCount);
@@ -189,15 +190,16 @@ RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpeci
Index offset = 0;
for (auto child : m_childComponents)
{
- auto childParamCount = child->getSpecializationParamCount();
- SLANG_ASSERT(offset + childParamCount <= argCount);
-
- auto childInfo = child->_validateSpecializationArgs(args + offset, childParamCount, sink);
-
+ Index consumedArgCount = 0;
+ auto childInfo = child->_validateSpecializationArgs(
+ args + offset,
+ argCount - offset,
+ consumedArgCount,
+ sink);
specializationInfo->childInfos.add(childInfo);
-
- offset += childParamCount;
+ offset += consumedArgCount;
}
+ outConsumedArgCount = offset;
return specializationInfo;
}
@@ -717,11 +719,13 @@ void TypeConformance::acceptVisitor(
RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
SLANG_UNUSED(args);
SLANG_UNUSED(argCount);
SLANG_UNUSED(sink);
+ outConsumedArgCount = 0;
return nullptr;
}
diff --git a/source/slang/slang-linkable-impls.h b/source/slang/slang-linkable-impls.h
index 68a16587d..21ea2fc9f 100644
--- a/source/slang/slang-linkable-impls.h
+++ b/source/slang/slang-linkable-impls.h
@@ -65,6 +65,7 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE;
public:
@@ -165,11 +166,13 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE
{
SLANG_UNUSED(args);
SLANG_UNUSED(argCount);
SLANG_UNUSED(sink);
+ outConsumedArgCount = 0;
return nullptr;
}
@@ -315,9 +318,10 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE
{
- return m_base->_validateSpecializationArgsImpl(args, argCount, sink);
+ return m_base->_validateSpecializationArgsImpl(args, argCount, outConsumedArgCount, sink);
}
};
@@ -513,6 +517,7 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE;
private:
diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp
index f7dc28171..0eef62742 100644
--- a/source/slang/slang-linkable.cpp
+++ b/source/slang/slang-linkable.cpp
@@ -371,9 +371,19 @@ RefPtr<ComponentType> ComponentType::specialize(
// (e.g., interface conformance witnesses) that doesn't get
// passed explicitly through the API interface.
//
- RefPtr<SpecializationInfo> specializationInfo =
- _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink);
-
+ Index consumedArgCount = 0;
+ RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs(
+ specializationArgs.getBuffer(),
+ specializationArgCount,
+ consumedArgCount,
+ sink);
+ if (consumedArgCount != specializationArgCount)
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ Math::Max(consumedArgCount, getSpecializationParamCount(), specializationArgCount));
+ }
return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink);
}
@@ -385,21 +395,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
{
DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer);
- // First let's check if the number of arguments given matches
- // the number of parameters that are present on this component type.
- //
- auto specializationParamCount = getSpecializationParamCount();
- if (specializationArgCount != specializationParamCount)
- {
- sink.diagnose(
- SourceLoc(),
- Diagnostics::mismatchSpecializationArguments,
- specializationParamCount,
- specializationArgCount);
- sink.getBlobIfNeeded(outDiagnostics);
- return SLANG_FAIL;
- }
-
List<SpecializationArg> expandedArgs;
for (Int aa = 0; aa < specializationArgCount; ++aa)
{
@@ -411,7 +406,21 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
case slang::SpecializationArg::Kind::Type:
expandedArg.val = asInternal(apiArg.type);
break;
-
+ case slang::SpecializationArg::Kind::Expr:
+ {
+ auto parsedExpr = parseExprFromString(apiArg.expr, &sink);
+ if (!parsedExpr)
+ return SLANG_FAIL;
+
+ SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, &sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+ auto checkedExpr = visitor.CheckTerm(parsedExpr);
+ if (auto typeType = as<TypeType>(checkedExpr->type.type))
+ expandedArg.val = typeType->getType();
+ else
+ expandedArg.expr = checkedExpr;
+ }
+ break;
default:
sink.getBlobIfNeeded(outDiagnostics);
return SLANG_FAIL;
@@ -425,7 +434,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
sink.getBlobIfNeeded(outDiagnostics);
*outSpecializedComponentType = specializedComponentType.detach();
-
+ if (sink.getErrorCount() != 0)
+ return SLANG_FAIL;
return SLANG_OK;
}
@@ -729,6 +739,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata(
return SLANG_OK;
}
+Expr* ComponentType::parseExprFromString(String exprStr, DiagnosticSink* sink)
+{
+ auto linkage = getLinkage();
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+ auto astBuilder = linkage->getASTBuilder();
+ Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder);
+ Expr* expr = linkage->parseTermString(exprStr, scope);
+ if (!expr || as<IncompleteExpr>(expr))
+ sink->diagnose(SourceLoc(), Diagnostics::syntaxError);
+ return expr;
+}
+
Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink)
{
// If we've looked up this type name before,
@@ -738,14 +760,6 @@ Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* si
if (m_types.tryGetValue(typeStr, type))
return type;
-
- // TODO(JS): For now just used the linkages ASTBuilder to keep on scope
- //
- // The parseTermString uses the linkage ASTBuilder for it's parsing.
- //
- // It might be possible to just create a temporary ASTBuilder - the worry though is
- // that the parsing sets a member variable in AST node to one of these scopes, and then
- // it become a dangling pointer. So for now we go with the linkages.
auto astBuilder = getLinkage()->getASTBuilder();
// Otherwise, we need to start looking in
diff --git a/source/slang/slang-linkable.h b/source/slang/slang-linkable.h
index 04a5c3c88..da97fbf9a 100644
--- a/source/slang/slang-linkable.h
+++ b/source/slang/slang-linkable.h
@@ -262,7 +262,7 @@ public:
/// it only really makes sense on `Module`.
///
Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink);
-
+ Expr* parseExprFromString(String expr, DiagnosticSink* sink);
Expr* findDeclFromString(String const& name, DiagnosticSink* sink);
Expr* tryResolveOverloadedExpr(Expr* exprIn);
@@ -348,6 +348,7 @@ public:
virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) = 0;
/// Validate the given specialization `args` and compute any side-band specialization info.
@@ -363,11 +364,12 @@ public:
RefPtr<SpecializationInfo> _validateSpecializationArgs(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink)
{
if (argCount == 0)
return nullptr;
- return _validateSpecializationArgsImpl(args, argCount, sink);
+ return _validateSpecializationArgsImpl(args, argCount, outConsumedArgCount, sink);
}
/// Specialize this component type given `specializationArgs`
diff --git a/source/slang/slang-module.h b/source/slang/slang-module.h
index 7a71f242d..bd911f0f0 100644
--- a/source/slang/slang-module.h
+++ b/source/slang/slang-module.h
@@ -451,6 +451,7 @@ protected:
RefPtr<SpecializationInfo> _validateSpecializationArgsImpl(
SpecializationArg const* args,
Index argCount,
+ Index& outConsumedArgCount,
DiagnosticSink* sink) SLANG_OVERRIDE;
private: