summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-08-09 09:43:25 -0700
committerGitHub <noreply@github.com>2025-08-09 16:43:25 +0000
commitdcdebc1a76a0a6ffbfd6a5805354f8f679c60202 (patch)
tree126d60d157e73e401aacf1e13b400b8533ec8828 /tools
parentfc6aea37483446372425aca8471f0e8bf7c3a910 (diff)
Allow specializing entrypoints with generic value args or variadic types from API (#8119)
Closes #8110. Closes #8011.
Diffstat (limited to 'tools')
-rw-r--r--tools/slang-unit-test/unit-test-generic-entrypoint.cpp119
1 files changed, 97 insertions, 22 deletions
diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp
index 741fe35bc..4f0b36edb 100644
--- a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp
+++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp
@@ -18,8 +18,8 @@ SLANG_UNIT_TEST(genericEntryPointCompile)
const char* userSourceBody = R"(
interface I { int getValue(); }
struct X : I { int getValue() { return 100; } }
- float4 vertMain<T:I>(uniform T o) {
- return float4(o.getValue(), 0, 0, 1);
+ float4 vertMain<T:I, int n, each U>(uniform T o) {
+ return float4(o.getValue(), countof(U), n, 1);
}
)";
ComPtr<slang::IGlobalSession> globalSession;
@@ -40,28 +40,103 @@ SLANG_UNIT_TEST(genericEntryPointCompile)
diagnosticBlob.writeRef());
SLANG_CHECK(module != nullptr);
- ComPtr<slang::IEntryPoint> entryPoint;
- module->findAndCheckEntryPoint(
- "vertMain<X>",
- SLANG_STAGE_VERTEX,
- entryPoint.writeRef(),
- diagnosticBlob.writeRef());
+ // Test 1: Using findAndCheckEntryPoint to supply arguments in string form.
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint(
+ "vertMain<X, 7, int, float>",
+ SLANG_STAGE_VERTEX,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(entryPoint != nullptr);
+ slang::IComponentType* componentTypes[2] = {module, entryPoint.get()};
+ ComPtr<slang::IComponentType> composedProgram;
+ session->createCompositeComponentType(
+ componentTypes,
+ 2,
+ composedProgram.writeRef(),
+ diagnosticBlob.writeRef());
- slang::IComponentType* componentTypes[2] = {module, entryPoint.get()};
- ComPtr<slang::IComponentType> composedProgram;
- session->createCompositeComponentType(
- componentTypes,
- 2,
- composedProgram.writeRef(),
- diagnosticBlob.writeRef());
+ ComPtr<slang::IComponentType> linkedProgram;
+ composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+
+ ComPtr<slang::IBlob> code;
+ linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+
+ SLANG_CHECK(
+ UnownedStringSlice((char*)code->getBufferPointer())
+ .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 7.0, 1.0)")) != -1);
+ }
+
+ // Test 2: Using `specialize` to supply arguments structurally with reflection types.
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint(
+ "vertMain",
+ SLANG_STAGE_VERTEX,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(entryPoint != nullptr);
+ ComPtr<slang::IComponentType> specializedEntryPoint;
+ slang::SpecializationArg args[] = {
+ slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")),
+ slang::SpecializationArg::fromExpr("8"),
+ slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("int")),
+ slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("float"))};
+
+ entryPoint->specialize(args, 4, specializedEntryPoint.writeRef(), nullptr);
+ SLANG_CHECK_ABORT(specializedEntryPoint != nullptr);
+ slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()};
+ ComPtr<slang::IComponentType> composedProgram;
+ session->createCompositeComponentType(
+ componentTypes,
+ 2,
+ composedProgram.writeRef(),
+ diagnosticBlob.writeRef());
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+
+ ComPtr<slang::IBlob> code;
+ linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+
+ SLANG_CHECK(
+ UnownedStringSlice((char*)code->getBufferPointer())
+ .indexOf(toSlice("vec4(float(X_getValue_0()), 2.0, 8.0, 1.0)")) != -1);
+ }
+
+ // Test 3: corner case: specialize variadic param with 0 types.
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint(
+ "vertMain",
+ SLANG_STAGE_VERTEX,
+ entryPoint.writeRef(),
+ diagnosticBlob.writeRef());
+ SLANG_CHECK_ABORT(entryPoint != nullptr);
+ ComPtr<slang::IComponentType> specializedEntryPoint;
+ slang::SpecializationArg args[] = {
+ slang::SpecializationArg::fromType(module->getLayout()->findTypeByName("X")),
+ slang::SpecializationArg::fromExpr("8")};
+
+ entryPoint->specialize(args, 2, specializedEntryPoint.writeRef(), nullptr);
+ SLANG_CHECK_ABORT(specializedEntryPoint != nullptr);
+ slang::IComponentType* componentTypes[2] = {module, specializedEntryPoint.get()};
+ ComPtr<slang::IComponentType> composedProgram;
+ session->createCompositeComponentType(
+ componentTypes,
+ 2,
+ composedProgram.writeRef(),
+ diagnosticBlob.writeRef());
- ComPtr<slang::IComponentType> linkedProgram;
- composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
+ ComPtr<slang::IComponentType> linkedProgram;
+ composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
- ComPtr<slang::IBlob> code;
- linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
+ ComPtr<slang::IBlob> code;
+ linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
- SLANG_CHECK(
- UnownedStringSlice((char*)code->getBufferPointer())
- .indexOf(toSlice("vec4(float(X_getValue")) != -1);
+ SLANG_CHECK(
+ UnownedStringSlice((char*)code->getBufferPointer())
+ .indexOf(toSlice("vec4(float(X_getValue_0()), 0.0, 8.0, 1.0)")) != -1);
+ }
}