diff options
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/slang-unit-test/unit-test-find-check-entrypoint.cpp | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp index 8ecab9671..75da9aaf0 100644 --- a/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp +++ b/tools/slang-unit-test/unit-test-find-check-entrypoint.cpp @@ -71,3 +71,68 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint) SLANG_CHECK(code != nullptr); SLANG_CHECK(code->getBufferSize() != 0); } + +// This test reproduces issue #6507, where it was noticed that compilation of +// tests/compute/simple.slang for PTX target generates invalid code. +// TODO: Remove this when issue #4760 is resolved, because at that point +// tests/compute/simple.slang should cover the same issue. +SLANG_UNIT_TEST(cudaCodeGenBug) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + RWStructuredBuffer<float> outputBuffer; + + [numthreads(4, 1, 1)] + void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) + { + outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x); + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_PTX; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "computeMain", + SLANG_STAGE_COMPUTE, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + ComPtr<slang::IComponentType> compositeProgram; + slang::IComponentType* components[] = {module, entryPoint.get()}; + session->createCompositeComponentType( + components, + 2, + compositeProgram.writeRef(), + diagnosticBlob.writeRef()); + SLANG_CHECK(compositeProgram != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(linkedProgram != nullptr); + + ComPtr<slang::IBlob> code; + auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(res == SLANG_OK); + SLANG_CHECK(code != nullptr); + SLANG_CHECK(code->getBufferSize() != 0); +} |
