// unit-test-translation-unit-import.cpp #include "../../source/core/slang-io.h" #include "../../source/core/slang-process.h" #include "slang-com-ptr.h" #include "slang.h" #include "unit-test/slang-unit-test.h" #include #include using namespace Slang; // Test that the IModule::findAndCheckEntryPoint API supports discovering // entrypoints without a [shader] attribute. SLANG_UNIT_TEST(findAndCheckEntryPoint) { // Source for a module that contains an undecorated entrypoint. const char* userSourceBody = R"( float4 fragMain(float4 pos:SV_Position) : SV_Target { return pos; } )"; auto moduleName = "moduleG" + String(Process::getId()); String userSource = "import " + moduleName + ";\n" + userSourceBody; ComPtr globalSession; SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); slang::TargetDesc targetDesc = {}; targetDesc.format = SLANG_SPIRV; targetDesc.profile = globalSession->findProfile("spirv_1_5"); slang::SessionDesc sessionDesc = {}; sessionDesc.targetCount = 1; sessionDesc.targets = &targetDesc; ComPtr session; SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); ComPtr diagnosticBlob; auto module = session->loadModuleFromSourceString( "m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); ComPtr entryPoint; module->findAndCheckEntryPoint( "fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(entryPoint != nullptr); ComPtr compositeProgram; slang::IComponentType* components[] = {module, entryPoint.get()}; session->createCompositeComponentType( components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(compositeProgram != nullptr); ComPtr linkedProgram; compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(linkedProgram != nullptr); ComPtr code; linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(code != nullptr); SLANG_CHECK(code->getBufferSize() != 0); } // This test reproduces issue #6507, where it was noticed that compilation of // tests/compute/simple.slang for PTX target generates invalid code. // TODO: Remove this when issue #4760 is resolved, because at that point // tests/compute/simple.slang should cover the same issue. SLANG_UNIT_TEST(cudaCodeGenBug) { // We need the CUDA backend for this test if (!SLANG_SUCCEEDED( unitTestContext->slangGlobalSession->checkPassThroughSupport(SLANG_PASS_THROUGH_NVRTC))) { SLANG_IGNORE_TEST; } // Source for a module that contains an undecorated entrypoint. const char* userSourceBody = R"( RWStructuredBuffer outputBuffer; [numthreads(4, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x); } )"; auto moduleName = "moduleG" + String(Process::getId()); String userSource = "import " + moduleName + ";\n" + userSourceBody; ComPtr globalSession; SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); slang::TargetDesc targetDesc = {}; targetDesc.format = SLANG_PTX; slang::SessionDesc sessionDesc = {}; sessionDesc.targetCount = 1; sessionDesc.targets = &targetDesc; ComPtr session; SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); ComPtr diagnosticBlob; auto module = session->loadModuleFromSourceString( "m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); SLANG_CHECK(module != nullptr); ComPtr entryPoint; module->findAndCheckEntryPoint( "computeMain", SLANG_STAGE_COMPUTE, entryPoint.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(entryPoint != nullptr); ComPtr compositeProgram; slang::IComponentType* components[] = {module, entryPoint.get()}; session->createCompositeComponentType( components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(compositeProgram != nullptr); ComPtr linkedProgram; compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(linkedProgram != nullptr); ComPtr code; auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(res == SLANG_OK); SLANG_CHECK(code != nullptr && code->getBufferSize() != 0); }