diff options
| author | Yong He <yonghe@outlook.com> | 2025-08-09 09:43:25 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-09 16:43:25 +0000 |
| commit | dcdebc1a76a0a6ffbfd6a5805354f8f679c60202 (patch) | |
| tree | 126d60d157e73e401aacf1e13b400b8533ec8828 /source/slang/slang-linkable.cpp | |
| parent | fc6aea37483446372425aca8471f0e8bf7c3a910 (diff) | |
Allow specializing entrypoints with generic value args or variadic types from API (#8119)
Closes #8110.
Closes #8011.
Diffstat (limited to 'source/slang/slang-linkable.cpp')
| -rw-r--r-- | source/slang/slang-linkable.cpp | 70 |
1 files changed, 42 insertions, 28 deletions
diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp index f7dc28171..0eef62742 100644 --- a/source/slang/slang-linkable.cpp +++ b/source/slang/slang-linkable.cpp @@ -371,9 +371,19 @@ RefPtr<ComponentType> ComponentType::specialize( // (e.g., interface conformance witnesses) that doesn't get // passed explicitly through the API interface. // - RefPtr<SpecializationInfo> specializationInfo = - _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink); - + Index consumedArgCount = 0; + RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + consumedArgCount, + sink); + if (consumedArgCount != specializationArgCount) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + Math::Max(consumedArgCount, getSpecializationParamCount(), specializationArgCount)); + } return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink); } @@ -385,21 +395,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( { DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - // First let's check if the number of arguments given matches - // the number of parameters that are present on this component type. - // - auto specializationParamCount = getSpecializationParamCount(); - if (specializationArgCount != specializationParamCount) - { - sink.diagnose( - SourceLoc(), - Diagnostics::mismatchSpecializationArguments, - specializationParamCount, - specializationArgCount); - sink.getBlobIfNeeded(outDiagnostics); - return SLANG_FAIL; - } - List<SpecializationArg> expandedArgs; for (Int aa = 0; aa < specializationArgCount; ++aa) { @@ -411,7 +406,21 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( case slang::SpecializationArg::Kind::Type: expandedArg.val = asInternal(apiArg.type); break; - + case slang::SpecializationArg::Kind::Expr: + { + auto parsedExpr = parseExprFromString(apiArg.expr, &sink); + if (!parsedExpr) + return SLANG_FAIL; + + SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, &sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + auto checkedExpr = visitor.CheckTerm(parsedExpr); + if (auto typeType = as<TypeType>(checkedExpr->type.type)) + expandedArg.val = typeType->getType(); + else + expandedArg.expr = checkedExpr; + } + break; default: sink.getBlobIfNeeded(outDiagnostics); return SLANG_FAIL; @@ -425,7 +434,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( sink.getBlobIfNeeded(outDiagnostics); *outSpecializedComponentType = specializedComponentType.detach(); - + if (sink.getErrorCount() != 0) + return SLANG_FAIL; return SLANG_OK; } @@ -729,6 +739,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata( return SLANG_OK; } +Expr* ComponentType::parseExprFromString(String exprStr, DiagnosticSink* sink) +{ + auto linkage = getLinkage(); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + auto astBuilder = linkage->getASTBuilder(); + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + Expr* expr = linkage->parseTermString(exprStr, scope); + if (!expr || as<IncompleteExpr>(expr)) + sink->diagnose(SourceLoc(), Diagnostics::syntaxError); + return expr; +} + Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink) { // If we've looked up this type name before, @@ -738,14 +760,6 @@ Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* si if (m_types.tryGetValue(typeStr, type)) return type; - - // TODO(JS): For now just used the linkages ASTBuilder to keep on scope - // - // The parseTermString uses the linkage ASTBuilder for it's parsing. - // - // It might be possible to just create a temporary ASTBuilder - the worry though is - // that the parsing sets a member variable in AST node to one of these scopes, and then - // it become a dangling pointer. So for now we go with the linkages. auto astBuilder = getLinkage()->getASTBuilder(); // Otherwise, we need to start looking in |
