diff options
| -rw-r--r-- | include/slang.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 174 | ||||
| -rw-r--r-- | source/slang/slang-entry-point.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-linkable-impls.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-linkable-impls.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-linkable.cpp | 70 | ||||
| -rw-r--r-- | source/slang/slang-linkable.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-module.h | 1 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-generic-entrypoint.cpp | 119 |
10 files changed, 274 insertions, 134 deletions
diff --git a/include/slang.h b/include/slang.h index 7462644a2..a32cbfd3b 100644 --- a/include/slang.h +++ b/include/slang.h @@ -4621,6 +4621,7 @@ struct SpecializationArg { Unknown, /**< An invalid specialization argument. */ Type, /**< Specialize to a type. */ + Expr, /**< An expression representing a type or value */ }; /** The kind of specialization argument. */ @@ -4629,6 +4630,8 @@ struct SpecializationArg { /** A type specialization argument, used for `Kind::Type`. */ TypeReflection* type; + /** An expression in Slang syntax, used for `Kind::Expr`. */ + const char* expr; }; static SpecializationArg fromType(TypeReflection* inType) @@ -4638,6 +4641,14 @@ struct SpecializationArg rs.type = inType; return rs; } + + static SpecializationArg fromExpr(const char* inExpr) + { + SpecializationArg rs; + rs.kind = Kind::Expr; + rs.expr = inExpr; + return rs; + } }; } // namespace slang diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 6dc4203a6..9ad58c776 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1627,6 +1627,7 @@ FIDDLE() namespace Slang struct SpecializationArg { Val* val = nullptr; + Expr* expr = nullptr; }; typedef List<SpecializationArg> SpecializationArgs; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index dc4f84920..45deea109 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -238,6 +238,25 @@ DeclRef<FuncDecl> findFunctionDeclByName(Module* translationUnit, Name* name, Di { entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); + if (!entryPointFuncDeclRef) + { + if (auto genDeclRef = as<GenericDecl>(declRefExpr->declRef)) + { + SharedSemanticsContext context( + translationUnit->getLinkage(), + translationUnit, + sink); + SemanticsVisitor visitor(&context); + entryPointFuncDeclRef = createDefaultSubstitutionsIfNeeded( + translationUnit->getASTBuilder(), + &visitor, + translationUnit->getASTBuilder()->getMemberDeclRef( + genDeclRef, + genDeclRef.getDecl()->inner)) + .as<FuncDecl>(); + } + } + if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit) entryPointFuncDeclRef = DeclRef<FuncDecl>(); } @@ -1251,9 +1270,20 @@ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { - SLANG_ASSERT(argCount == getSpecializationParamCount()); + if (argCount < getSpecializationParamCount()) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + getSpecializationParamCount(), + argCount); + return nullptr; + } + outConsumedArgCount = getSpecializationParamCount(); + argCount = outConsumedArgCount; SharedSemanticsContext semanticsContext(getLinkage(), this, sink); SemanticsVisitor visitor(&semanticsContext); @@ -1468,6 +1498,7 @@ static void _extractSpecializationArgs( RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl( SpecializationArg const* inArgs, Index inArgCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { auto args = inArgs; @@ -1476,15 +1507,16 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); - // The first N arguments will be for the explicit generic parameters + // The last N arguments will be for the implicit existential arguments // of the entry point (if it has any). // + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); auto genericSpecializationParamCount = getGenericSpecializationParamCount(); - SLANG_ASSERT(argCount >= genericSpecializationParamCount); RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo(); DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef; + Index genericArgCount = genericSpecializationParamCount; if (genericSpecializationParamCount) { // We need to construct a generic application and use @@ -1494,83 +1526,69 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - List<Val*> genericArgs; + bool isVariadic = + (genericDeclRef.getDecl()->getMembersOfType<GenericTypePackParamDecl>().getCount() != + 0); + + // If function is variadic generic, it will consume all the provided arguments. + if (isVariadic) + genericArgCount = argCount - existentialSpecializationParamCount; - for (Index ii = 0; ii < genericSpecializationParamCount; ++ii) + if (genericArgCount < 0) { - auto specializationArg = args[ii]; - genericArgs.add(specializationArg.val); + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; } - auto astBuilder = getLinkage()->getASTBuilder(); - for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>( - getLinkage()->getASTBuilder(), - DeclRef<ContainerDecl>(genericDeclRef))) - { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = - astBuilder->getDirectDeclRef(constraintDecl.getDecl()); - int argIndex = -1; - int ii = 0; - - // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to. - auto genericParamType = getSub(astBuilder, constraintDeclRef); - auto genParamDeclRefType = as<DeclRefType>(genericParamType); - if (!genParamDeclRefType) - { - continue; - } - auto genParamDeclRef = genParamDeclRefType->getDeclRef(); - // Find the generic argument index of the corresponding generic parameter type in the - // generic parameter set. - // - for (auto member : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>()) - { - if (member == genParamDeclRef.getDecl()) - { - argIndex = ii; - break; - } - ii++; - } - if (argIndex == -1) - { - SLANG_ASSERT(!"generic parameter not found in generic decl"); - continue; - } - auto sub = as<Type>(args[argIndex].val); - if (!sub) - { - sink->diagnose( - constraintDecl, - Diagnostics::expectedTypeForSpecializationArg, - argIndex); - continue; - } + List<Expr*> genericArgs; - auto sup = getSup(astBuilder, constraintDeclRef); - auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None); - if (subTypeWitness) - { - genericArgs.add(subTypeWitness); - } - else + auto astBuilder = getLinkage()->getASTBuilder(); + for (Index ii = 0; ii < genericArgCount; ++ii) + { + auto specializationArg = args[ii]; + if (specializationArg.expr) { - // TODO: diagnose a problem here - sink->diagnose( - constraintDecl, - Diagnostics::typeArgumentDoesNotConformToInterface, - sub, - sup); + genericArgs.add(specializationArg.expr); continue; } + auto typeExpr = astBuilder->create<SharedTypeExpr>(); + typeExpr->type = astBuilder->getTypeType((Type*)specializationArg.val); + genericArgs.add(typeExpr); + } + auto genAppExpr = astBuilder->create<GenericAppExpr>(); + auto genExpr = astBuilder->create<VarExpr>(); + genExpr->declRef = genericDeclRef; + genExpr->type = astBuilder->getOrCreate<GenericDeclRefType>(); + genExpr->checked = true; + genAppExpr->functionExpr = genExpr; + genAppExpr->arguments = _Move(genericArgs); + auto checkedExpr = visitor.CheckTerm(genAppExpr); + if (auto partiallyAppliedExpr = as<PartiallyAppliedGenericExpr>(checkedExpr)) + { + // If checked generic is partially applied generic, we try to force conversion into + // a fully defined declref by calling `trySolveConstraintSystem`. + SemanticsVisitor::ConstraintSystem system; + system.genericDecl = genericDeclRef.getDecl(); + ConversionCost outCost; + specializedFuncDeclRef = visitor + .trySolveConstraintSystem( + &system, + genericDeclRef, + partiallyAppliedExpr->knownGenericArgs.getArrayView(), + outCost) + .as<FuncDecl>(); + } + else if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) + { + specializedFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); } - specializedFuncDeclRef = - getLinkage() - ->getASTBuilder() - ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()) - .as<FuncDecl>(); - SLANG_ASSERT(specializedFuncDeclRef); + if (!specializedFuncDeclRef) + return nullptr; } info->specializedFuncDeclRef = specializedFuncDeclRef; @@ -1580,11 +1598,19 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg // specialization parameters, attached to the value parameters // of the entry point. // - args += genericSpecializationParamCount; - argCount -= genericSpecializationParamCount; + args += genericArgCount; + argCount -= genericArgCount; + outConsumedArgCount = genericArgCount + existentialSpecializationParamCount; - auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); - SLANG_ASSERT(argCount == existentialSpecializationParamCount); + if (argCount < existentialSpecializationParamCount) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; + } for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii) { diff --git a/source/slang/slang-entry-point.h b/source/slang/slang-entry-point.h index 16499a542..305e2c77e 100644 --- a/source/slang/slang-entry-point.h +++ b/source/slang/slang-entry-point.h @@ -292,6 +292,7 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE; private: diff --git a/source/slang/slang-linkable-impls.cpp b/source/slang/slang-linkable-impls.cpp index d03ecb3ca..082974fae 100644 --- a/source/slang/slang-linkable-impls.cpp +++ b/source/slang/slang-linkable-impls.cpp @@ -180,6 +180,7 @@ void CompositeComponentType::acceptVisitor( RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { SLANG_UNUSED(argCount); @@ -189,15 +190,16 @@ RefPtr<ComponentType::SpecializationInfo> CompositeComponentType::_validateSpeci Index offset = 0; for (auto child : m_childComponents) { - auto childParamCount = child->getSpecializationParamCount(); - SLANG_ASSERT(offset + childParamCount <= argCount); - - auto childInfo = child->_validateSpecializationArgs(args + offset, childParamCount, sink); - + Index consumedArgCount = 0; + auto childInfo = child->_validateSpecializationArgs( + args + offset, + argCount - offset, + consumedArgCount, + sink); specializationInfo->childInfos.add(childInfo); - - offset += childParamCount; + offset += consumedArgCount; } + outConsumedArgCount = offset; return specializationInfo; } @@ -717,11 +719,13 @@ void TypeConformance::acceptVisitor( RefPtr<ComponentType::SpecializationInfo> TypeConformance::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { SLANG_UNUSED(args); SLANG_UNUSED(argCount); SLANG_UNUSED(sink); + outConsumedArgCount = 0; return nullptr; } diff --git a/source/slang/slang-linkable-impls.h b/source/slang/slang-linkable-impls.h index 68a16587d..21ea2fc9f 100644 --- a/source/slang/slang-linkable-impls.h +++ b/source/slang/slang-linkable-impls.h @@ -65,6 +65,7 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE; public: @@ -165,11 +166,13 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE { SLANG_UNUSED(args); SLANG_UNUSED(argCount); SLANG_UNUSED(sink); + outConsumedArgCount = 0; return nullptr; } @@ -315,9 +318,10 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE { - return m_base->_validateSpecializationArgsImpl(args, argCount, sink); + return m_base->_validateSpecializationArgsImpl(args, argCount, outConsumedArgCount, sink); } }; @@ -513,6 +517,7 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE; private: diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp index f7dc28171..0eef62742 100644 --- a/source/slang/slang-linkable.cpp +++ b/source/slang/slang-linkable.cpp @@ -371,9 +371,19 @@ RefPtr<ComponentType> ComponentType::specialize( // (e.g., interface conformance witnesses) that doesn't get // passed explicitly through the API interface. // - RefPtr<SpecializationInfo> specializationInfo = - _validateSpecializationArgs(specializationArgs.getBuffer(), specializationArgCount, sink); - + Index consumedArgCount = 0; + RefPtr<SpecializationInfo> specializationInfo = _validateSpecializationArgs( + specializationArgs.getBuffer(), + specializationArgCount, + consumedArgCount, + sink); + if (consumedArgCount != specializationArgCount) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + Math::Max(consumedArgCount, getSpecializationParamCount(), specializationArgCount)); + } return new SpecializedComponentType(this, specializationInfo, specializationArgs, sink); } @@ -385,21 +395,6 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( { DiagnosticSink sink(getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - // First let's check if the number of arguments given matches - // the number of parameters that are present on this component type. - // - auto specializationParamCount = getSpecializationParamCount(); - if (specializationArgCount != specializationParamCount) - { - sink.diagnose( - SourceLoc(), - Diagnostics::mismatchSpecializationArguments, - specializationParamCount, - specializationArgCount); - sink.getBlobIfNeeded(outDiagnostics); - return SLANG_FAIL; - } - List<SpecializationArg> expandedArgs; for (Int aa = 0; aa < specializationArgCount; ++aa) { @@ -411,7 +406,21 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( case slang::SpecializationArg::Kind::Type: expandedArg.val = asInternal(apiArg.type); break; - + case slang::SpecializationArg::Kind::Expr: + { + auto parsedExpr = parseExprFromString(apiArg.expr, &sink); + if (!parsedExpr) + return SLANG_FAIL; + + SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, &sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + auto checkedExpr = visitor.CheckTerm(parsedExpr); + if (auto typeType = as<TypeType>(checkedExpr->type.type)) + expandedArg.val = typeType->getType(); + else + expandedArg.expr = checkedExpr; + } + break; default: sink.getBlobIfNeeded(outDiagnostics); return SLANG_FAIL; @@ -425,7 +434,8 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::specialize( sink.getBlobIfNeeded(outDiagnostics); *outSpecializedComponentType = specializedComponentType.detach(); - + if (sink.getErrorCount() != 0) + return SLANG_FAIL; return SLANG_OK; } @@ -729,6 +739,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getTargetMetadata( return SLANG_OK; } +Expr* ComponentType::parseExprFromString(String exprStr, DiagnosticSink* sink) +{ + auto linkage = getLinkage(); + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + auto astBuilder = linkage->getASTBuilder(); + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + Expr* expr = linkage->parseTermString(exprStr, scope); + if (!expr || as<IncompleteExpr>(expr)) + sink->diagnose(SourceLoc(), Diagnostics::syntaxError); + return expr; +} + Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* sink) { // If we've looked up this type name before, @@ -738,14 +760,6 @@ Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* si if (m_types.tryGetValue(typeStr, type)) return type; - - // TODO(JS): For now just used the linkages ASTBuilder to keep on scope - // - // The parseTermString uses the linkage ASTBuilder for it's parsing. - // - // It might be possible to just create a temporary ASTBuilder - the worry though is - // that the parsing sets a member variable in AST node to one of these scopes, and then - // it become a dangling pointer. So for now we go with the linkages. auto astBuilder = getLinkage()->getASTBuilder(); // Otherwise, we need to start looking in diff --git a/source/slang/slang-linkable.h b/source/slang/slang-linkable.h index 04a5c3c88..da97fbf9a 100644 --- a/source/slang/slang-linkable.h +++ b/source/slang/slang-linkable.h @@ -262,7 +262,7 @@ public: /// it only really makes sense on `Module`. /// Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink); - + Expr* parseExprFromString(String expr, DiagnosticSink* sink); Expr* findDeclFromString(String const& name, DiagnosticSink* sink); Expr* tryResolveOverloadedExpr(Expr* exprIn); @@ -348,6 +348,7 @@ public: virtual RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) = 0; /// Validate the given specialization `args` and compute any side-band specialization info. @@ -363,11 +364,12 @@ public: RefPtr<SpecializationInfo> _validateSpecializationArgs( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { if (argCount == 0) return nullptr; - return _validateSpecializationArgsImpl(args, argCount, sink); + return _validateSpecializationArgsImpl(args, argCount, outConsumedArgCount, sink); } /// Specialize this component type given `specializationArgs` diff --git a/source/slang/slang-module.h b/source/slang/slang-module.h index 7a71f242d..bd911f0f0 100644 --- a/source/slang/slang-module.h +++ b/source/slang/slang-module.h @@ -451,6 +451,7 @@ protected: RefPtr<SpecializationInfo> _validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) SLANG_OVERRIDE; private: diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp index 741fe35bc..4f0b36edb 100644 --- a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp +++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp @@ -18,8 +18,8 @@ SLANG_UNIT_TEST(genericEntryPointCompile) const char* userSourceBody = R"( interface I { int getValue(); } struct X : I { int getValue() { return 100; } } - float4 vertMain<T:I>(uniform T o) { - return float4(o.getValue(), 0, 0, 1); + float4 vertMain<T:I, int n, each U>(uniform T o) { + return float4(o.getValue(), countof(U), n, 1); } )"; ComPtr<slang::IGlobalSession> globalSession; @@ -40,28 +40,103 @@ SLANG_UNIT_TEST(genericEntryPointCompile) diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); - ComPtr<slang::IEntryPoint> entryPoint; - module->findAndCheckEntryPoint( - "vertMain<X>", - SLANG_STAGE_VERTEX, - entryPoint.writeRef(), - diagnosticBlob.writeRef()); + // Test 1: Using findAndCheckEntryPoint to supply arguments in string form. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain<X, 7, int, float>", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; - ComPtr<slang::IComponentType> composedProgram; - session->createCompositeComponentType( - componentTypes, - 2, - composedProgram.writeRef(), - diagnosticBlob.writeRef()); + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 7.0, 1.0)")) != -1); + } + + // Test 2: Using `specialize` to supply arguments structurally with reflection types. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr<slang::IComponentType> specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8"), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("int")), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("float"))}; + + entryPoint->specialize(args, 4, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 8.0, 1.0)")) != -1); + } + + // Test 3: corner case: specialize variadic param with 0 types. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr<slang::IComponentType> specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8")}; + + entryPoint->specialize(args, 2, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - ComPtr<slang::IComponentType> linkedProgram; - composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); - ComPtr<slang::IBlob> code; - linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); - SLANG_CHECK( - UnownedStringSlice((char*)code->getBufferPointer()) - .indexOf(toSlice("vec4(float(X_getValue")) != -1); + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 0.0, 8.0, 1.0)")) != -1); + } } |
