summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-linkable.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-08-09 09:43:25 -0700
committerGitHub <noreply@github.com>2025-08-09 16:43:25 +0000
commitdcdebc1a76a0a6ffbfd6a5805354f8f679c60202 (patch)
tree126d60d157e73e401aacf1e13b400b8533ec8828 /source/slang/slang-linkable.cpp
parentfc6aea37483446372425aca8471f0e8bf7c3a910 (diff)
Allow specializing entrypoints with generic value args or variadic types from API (#8119)
Closes #8110. Closes #8011.
Diffstat (limited to 'source/slang/slang-linkable.cpp')
-rw-r--r--source/slang/slang-linkable.cpp70
1 files changed, 42 insertions, 28 deletions
diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp
index f7dc28171..0eef62742 100644
--- a/source/slang/slang-linkable.cpp
+++ b/source/slang/slang-linkable.cpp
@@ -371,9 +371,19 @@ RefPtr<ComponentType> ComponentType::specialize(
// (e.g., interface conformance witnesses) that doesn't get
// passed explicitly through the API interface.
//
- RefPtr<SpecializationInfo> specializationInfo =
- _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink);
-
+ Index consumedArgCount = 0;
+ RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs(
+ specializationArgs.getBuffer(),
+ specializationArgCount,
+ consumedArgCount,
+ sink);
+ if (consumedArgCount != specializationArgCount)
+ {
+ sink->diagnose(
+ SourceLoc(),
+ Diagnostics::mismatchSpecializationArguments,
+ Math::Max(consumedArgCount, getSpecializationParamCount(), specializationArgCount));
+ }
return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink);
}
@@ -385,21 +395,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
{
DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer);
- // First let's check if the number of arguments given matches
- // the number of parameters that are present on this component type.
- //
- auto specializationParamCount = getSpecializationParamCount();
- if (specializationArgCount != specializationParamCount)
- {
- sink.diagnose(
- SourceLoc(),
- Diagnostics::mismatchSpecializationArguments,
- specializationParamCount,
- specializationArgCount);
- sink.getBlobIfNeeded(outDiagnostics);
- return SLANG_FAIL;
- }
-
List<SpecializationArg> expandedArgs;
for (Int aa = 0; aa < specializationArgCount; ++aa)
{
@@ -411,7 +406,21 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
case slang::SpecializationArg::Kind::Type:
expandedArg.val = asInternal(apiArg.type);
break;
-
+ case slang::SpecializationArg::Kind::Expr:
+ {
+ auto parsedExpr = parseExprFromString(apiArg.expr, &sink);
+ if (!parsedExpr)
+ return SLANG_FAIL;
+
+ SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, &sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+ auto checkedExpr = visitor.CheckTerm(parsedExpr);
+ if (auto typeType = as<TypeType>(checkedExpr->type.type))
+ expandedArg.val = typeType->getType();
+ else
+ expandedArg.expr = checkedExpr;
+ }
+ break;
default:
sink.getBlobIfNeeded(outDiagnostics);
return SLANG_FAIL;
@@ -425,7 +434,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize(
sink.getBlobIfNeeded(outDiagnostics);
*outSpecializedComponentType = specializedComponentType.detach();
-
+ if (sink.getErrorCount() != 0)
+ return SLANG_FAIL;
return SLANG_OK;
}
@@ -729,6 +739,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata(
return SLANG_OK;
}
+Expr* ComponentType::parseExprFromString(String exprStr, DiagnosticSink* sink)
+{
+ auto linkage = getLinkage();
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+ auto astBuilder = linkage->getASTBuilder();
+ Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder);
+ Expr* expr = linkage->parseTermString(exprStr, scope);
+ if (!expr || as<IncompleteExpr>(expr))
+ sink->diagnose(SourceLoc(), Diagnostics::syntaxError);
+ return expr;
+}
+
Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink)
{
// If we've looked up this type name before,
@@ -738,14 +760,6 @@ Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* si
if (m_types.tryGetValue(typeStr, type))
return type;
-
- // TODO(JS): For now just used the linkages ASTBuilder to keep on scope
- //
- // The parseTermString uses the linkage ASTBuilder for it's parsing.
- //
- // It might be possible to just create a temporary ASTBuilder - the worry though is
- // that the parsing sets a member variable in AST node to one of these scopes, and then
- // it become a dangling pointer. So for now we go with the linkages.
auto astBuilder = getLinkage()->getASTBuilder();
// Otherwise, we need to start looking in