diff options
| author | Yong He <yonghe@outlook.com> | 2025-08-09 09:43:25 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-09 16:43:25 +0000 |
| commit | dcdebc1a76a0a6ffbfd6a5805354f8f679c60202 (patch) | |
| tree | 126d60d157e73e401aacf1e13b400b8533ec8828 /tools | |
| parent | fc6aea37483446372425aca8471f0e8bf7c3a910 (diff) | |
Allow specializing entrypoints with generic value args or variadic types from API (#8119)
Closes #8110.
Closes #8011.
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-generic-entrypoint.cpp | 119 |
1 files changed, 97 insertions, 22 deletions
diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp index 741fe35bc..4f0b36edb 100644 --- a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp +++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp @@ -18,8 +18,8 @@ SLANG_UNIT_TEST(genericEntryPointCompile) const char* userSourceBody = R"( interface I { int getValue(); } struct X : I { int getValue() { return 100; } } - float4 vertMain<T:I>(uniform T o) { - return float4(o.getValue(), 0, 0, 1); + float4 vertMain<T:I, int n, each U>(uniform T o) { + return float4(o.getValue(), countof(U), n, 1); } )"; ComPtr<slang::IGlobalSession> globalSession; @@ -40,28 +40,103 @@ SLANG_UNIT_TEST(genericEntryPointCompile) diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); - ComPtr<slang::IEntryPoint> entryPoint; - module->findAndCheckEntryPoint( - "vertMain<X>", - SLANG_STAGE_VERTEX, - entryPoint.writeRef(), - diagnosticBlob.writeRef()); + // Test 1: Using findAndCheckEntryPoint to supply arguments in string form. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain<X, 7, int, float>", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; - ComPtr<slang::IComponentType> composedProgram; - session->createCompositeComponentType( - componentTypes, - 2, - composedProgram.writeRef(), - diagnosticBlob.writeRef()); + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 7.0, 1.0)")) != -1); + } + + // Test 2: Using `specialize` to supply arguments structurally with reflection types. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr<slang::IComponentType> specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8"), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("int")), + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("float"))}; + + entryPoint->specialize(args, 4, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 8.0, 1.0)")) != -1); + } + + // Test 3: corner case: specialize variadic param with 0 types. + { + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK_ABORT(entryPoint != nullptr); + ComPtr<slang::IComponentType> specializedEntryPoint; + slang::SpecializationArg args[] = { + slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")), + slang::SpecializationArg::fromExpr("8")}; + + entryPoint->specialize(args, 2, specializedEntryPoint.writeRef(), nullptr); + SLANG_CHECK_ABORT(specializedEntryPoint != nullptr); + slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); - ComPtr<slang::IComponentType> linkedProgram; - composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); - ComPtr<slang::IBlob> code; - linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + ComPtr<slang::IBlob> code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); - SLANG_CHECK( - UnownedStringSlice((char*)code->getBufferPointer()) - .indexOf(toSlice("vec4(float(X_getValue")) != -1); + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue_0()), 0.0, 8.0, 1.0)")) != -1); + } } |
