From dcdebc1a76a0a6ffbfd6a5805354f8f679c60202 Mon Sep 17 00:00:00 2001 From: Yong He Date: Sat, 9 Aug 2025 09:43:25 -0700 Subject: Allow specializing entrypoints with generic value args or variadic types from API (#8119) Closes #8110. Closes #8011. --- .../unit-test-generic-entrypoint.cpp | 119 +++++++++++++++++---- 1 file changed, 97 insertions(+), 22 deletions(-) (limited to 'tools') diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp index 741fe35bc..4f0b36edb 100644 --- a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp +++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp @@ -18,8 +18,8 @@ SLANG_UNIT_TEST(genericEntryPointCompile) const char* userSourceBody = R"( interface I { int getValue(); } struct X : I { int getValue() { return 100; } } - float4 vertMain(uniform T o) { - return float4(o.getValue(), 0, 0, 1); + float4 vertMain(uniform T o) { + return float4(o.getValue(), countof(U), n, 1); } )"; ComPtr globalSession; @@ -40,28 +40,103 @@ SLANG_UNIT_TEST(genericEntryPointCompile) diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); - ComPtr entryPoint; - module->findAndCheckEntryPoint( - "vertMain", - SLANG_STAGE_VERTEX, - entryPoint.writeRef(), - diagnosticBlob.writeRef()); + // Test 1: Using findAndCheckEntryPoint to supply arguments in string form. + { + ComPtr entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; - ComPtr composedProgram; - session->createCompositeComponentType( - componentTypes, - 2, - composedProgram.writeRef(), - diagnosticBlob.writeRef()); + ComPtr linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 7.0, 1.0)")) != -1); + } + + // Test 2: Using `specialize` to supply arguments structurally with reflection types. + { + ComPtr entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8"), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("int")), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("float"))}; + + entryPoint->specialize(args, 4, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 8.0, 1.0)")) != -1); + } + + // Test 3: corner case: specialize variadic param with 0 types. + { + ComPtr entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8")}; + + entryPoint->specialize(args, 2, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - ComPtr linkedProgram; - composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + ComPtr linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); - ComPtr code; - linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); - SLANG_CHECK( - UnownedStringSlice((char*)code->getBufferPointer()) - .indexOf(toSlice("vec4(float(X_getValue")) != -1); + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 0.0, 8.0, 1.0)")) != -1); + } } -- cgit v1.2.3