From 6e24244832d9032f0993cb088af625238096b723 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 17 Dec 2024 21:57:58 -0800 Subject: Support specializing generic entrypoints in `findAndCheckEntryPoint`. (#5890) --- .../unit-test-generic-entrypoint.cpp | 67 +++++++++++++++++++++ tools/slang-unit-test/unit-test-glsl-compile.cpp | 69 ++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 tools/slang-unit-test/unit-test-generic-entrypoint.cpp create mode 100644 tools/slang-unit-test/unit-test-glsl-compile.cpp (limited to 'tools') diff --git a/tools/slang-unit-test/unit-test-generic-entrypoint.cpp b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp new file mode 100644 index 000000000..741fe35bc --- /dev/null +++ b/tools/slang-unit-test/unit-test-generic-entrypoint.cpp @@ -0,0 +1,67 @@ +// unit-test-generic-entrypoint.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include + +using namespace Slang; + +// Test the compilation API for compiling a specialized generic entrypoint. + +SLANG_UNIT_TEST(genericEntryPointCompile) +{ + const char* userSourceBody = R"( + interface I { int getValue(); } + struct X : I { int getValue() { return 100; } } + float4 vertMain(uniform T o) { + return float4(o.getValue(), 0, 0, 1); + } + )"; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_GLSL; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK( + UnownedStringSlice((char*)code->getBufferPointer()) + .indexOf(toSlice("vec4(float(X_getValue")) != -1); +} diff --git a/tools/slang-unit-test/unit-test-glsl-compile.cpp b/tools/slang-unit-test/unit-test-glsl-compile.cpp new file mode 100644 index 000000000..5a49da91b --- /dev/null +++ b/tools/slang-unit-test/unit-test-glsl-compile.cpp @@ -0,0 +1,69 @@ +// unit-test-glsl-compile.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include +#include + +using namespace Slang; + +// Test the compilation API for cross-compiling glsl source to SPIRV. + +SLANG_UNIT_TEST(glslCompile) +{ + const char* userSourceBody = R"( + #version 450 core + layout(location = 0) in vec2 aPosition; + layout(location = 1) in vec4 aColor; + layout(location = 0) out vec4 vColor; + void main() { + vColor = aColor; + gl_Position = vec4(aPosition, 0, 1); + } + )"; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_SPIRV; + targetDesc.profile = globalSession->findProfile("spirv_1_5"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr entryPoint; + module->findAndCheckEntryPoint( + "main", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + ComPtr code; + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK(code != nullptr); +} -- cgit v1.2.3