#include "gfx-test-util.h"

#include "slang-com-ptr.h"
#include "unit-test/slang-unit-test.h"

// Compile-time switches for optional debugging aids; both are off by default.
#define GFX_ENABLE_RENDERDOC_INTEGRATION 0
#define GFX_ENABLE_SPIRV_DEBUG 0

#if GFX_ENABLE_RENDERDOC_INTEGRATION
#include "external/renderdoc_app.h"

#include <windows.h>
#endif

using Slang::ComPtr;

namespace gfx_test
{

/// Forwards the contents of `diagnosticsBlob` to the test reporter as an
/// informational message. Does nothing when the blob is null.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob)
{
    if (diagnosticsBlob == nullptr)
        return;

    getTestReporter()->message(
        TestMessageType::Info,
        static_cast<const char*>(diagnosticsBlob->getBufferPointer()));
}

/// Shared tail of the program loaders: composes `componentTypes` into a single
/// component type, links it, publishes the linked layout through
/// `slangReflection`, and creates the device shader program.
///
/// Diagnostics produced by each step are reported through `diagnoseIfNeeded`
/// before the step's result is checked, so messages are emitted even when the
/// step fails.
///
/// @return SLANG_OK when `outShaderProgram` ends up non-null, the failing
///         step's result if composing or linking fails, SLANG_FAIL otherwise.
static Result composeLinkAndCreateProgram(
    IDevice* device,
    slang::ISession* slangSession,
    const std::vector<slang::IComponentType*>& componentTypes,
    ComPtr<IShaderProgram>& outShaderProgram,
    slang::ProgramLayout*& slangReflection)
{
    ComPtr<slang::IBlob> diagnosticsBlob;

    ComPtr<slang::IComponentType> composedProgram;
    Result result = slangSession->createCompositeComponentType(
        componentTypes.data(),
        componentTypes.size(),
        composedProgram.writeRef(),
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    ComPtr<slang::IComponentType> linkedProgram;
    result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    slangReflection = linkedProgram->getLayout();

    outShaderProgram = device->createShaderProgram(linkedProgram, diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    return outShaderProgram ? SLANG_OK : SLANG_FAIL;
}

/// Loads `shaderModuleName` through the caller-supplied `slangSession`, looks
/// up the entry point `entryPointName`, and builds a shader program from the
/// module plus that entry point.
///
/// @param device            Device used to create the shader program.
/// @param slangSession      Session used to load, compose and link the module.
/// @param outShaderProgram  Receives the created program.
/// @param shaderModuleName  Name of the module to load.
/// @param entryPointName    Name of the entry point inside the module.
/// @param slangReflection   Receives the layout of the linked program.
/// @return SLANG_OK on success; SLANG_FAIL if the module cannot be loaded or
///         the program cannot be created; otherwise the failing step's result.
Result loadComputeProgram(
    IDevice* device,
    slang::ISession* slangSession,
    ComPtr<IShaderProgram>& outShaderProgram,
    const char* shaderModuleName,
    const char* entryPointName,
    slang::ProgramLayout*& slangReflection)
{
    ComPtr<slang::IBlob> diagnosticsBlob;
    slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!module)
        return SLANG_FAIL;

    ComPtr<slang::IEntryPoint> computeEntryPoint;
    SLANG_RETURN_ON_FAIL(
        module->findEntryPointByName(entryPointName, computeEntryPoint.writeRef()));

    std::vector<slang::IComponentType*> componentTypes;
    componentTypes.push_back(module);
    componentTypes.push_back(computeEntryPoint);

    return composeLinkAndCreateProgram(
        device,
        slangSession,
        componentTypes,
        outShaderProgram,
        slangReflection);
}

/// Convenience overload of `loadComputeProgram` that uses the Slang session
/// obtained from `device` instead of a caller-supplied one.
///
/// @return The result of querying the session if that fails; otherwise the
///         result of the session-taking overload.
Result loadComputeProgram(
    IDevice* device,
    ComPtr<IShaderProgram>& outShaderProgram,
    const char* shaderModuleName,
    const char* entryPointName,
    slang::ProgramLayout*& slangReflection)
{
    ComPtr<slang::ISession> slangSession;
    SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));

    return loadComputeProgram(
        device,
        slangSession.get(),
        outShaderProgram,
        shaderModuleName,
        entryPointName,
        slangReflection);
}

/// Builds a shader program from in-memory Slang `source`.
///
/// The module is loaded under a name derived from a hash of the source text,
/// and every entry point the module defines is added to the composite. Unlike
/// the name-based loaders, the composite is handed to the device without an
/// explicit link step and no reflection layout is returned.
///
/// @return SLANG_OK when `outShaderProgram` ends up non-null; SLANG_FAIL if
///         the module fails to load or the program cannot be created;
///         otherwise the failing step's result.
Result loadComputeProgramFromSource(
    IDevice* device,
    ComPtr<IShaderProgram>& outShaderProgram,
    std::string_view source)
{
    auto slangSession = device->getSlangSession();
    ComPtr<slang::IBlob> diagnosticsBlob;

    const size_t hash = std::hash<std::string_view>()(source);
    const std::string moduleName = "source_module_" + std::to_string(hash);

    // The blob does not own the text; `source` must stay alive while it is used.
    auto srcBlob = Slang::UnownedRawBlob::create(source.data(), source.size());
    slang::IModule* module = slangSession->loadModuleFromSource(
        moduleName.c_str(),
        moduleName.c_str(),
        srcBlob,
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!module)
        return SLANG_FAIL;

    // Hold references to the module and every entry point while composing.
    std::vector<ComPtr<slang::IComponentType>> componentTypes;
    componentTypes.push_back(ComPtr<slang::IComponentType>(module));

    for (SlangInt32 i = 0; i < module->getDefinedEntryPointCount(); i++)
    {
        ComPtr<slang::IEntryPoint> entryPoint;
        SLANG_RETURN_ON_FAIL(module->getDefinedEntryPoint(i, entryPoint.writeRef()));
        componentTypes.push_back(ComPtr<slang::IComponentType>(entryPoint.get()));
    }

    // createCompositeComponentType expects an array of raw pointers.
    std::vector<slang::IComponentType*> rawComponentTypes;
    for (auto& compType : componentTypes)
        rawComponentTypes.push_back(compType.get());

    ComPtr<slang::IComponentType> linkedProgram;
    Result result = slangSession->createCompositeComponentType(
        rawComponentTypes.data(),
        rawComponentTypes.size(),
        linkedProgram.writeRef(),
        diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    SLANG_RETURN_ON_FAIL(result);

    outShaderProgram = device->createShaderProgram(linkedProgram, diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    return outShaderProgram ? SLANG_OK : SLANG_FAIL;
}

/// Loads `shaderModuleName` through the device's Slang session and builds a
/// shader program from the module plus the named vertex and fragment entry
/// points (composed in that order).
///
/// @param slangReflection  Receives the layout of the linked program.
/// @return SLANG_OK on success; SLANG_FAIL if the module cannot be loaded or
///         the program cannot be created; otherwise the failing step's result.
Result loadGraphicsProgram(
    IDevice* device,
    ComPtr<IShaderProgram>& outShaderProgram,
    const char* shaderModuleName,
    const char* vertexEntryPointName,
    const char* fragmentEntryPointName,
    slang::ProgramLayout*& slangReflection)
{
    ComPtr<slang::ISession> slangSession;
    SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));

    ComPtr<slang::IBlob> diagnosticsBlob;
    slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
    diagnoseIfNeeded(diagnosticsBlob);
    if (!module)
        return SLANG_FAIL;

    ComPtr<slang::IEntryPoint> vertexEntryPoint;
    SLANG_RETURN_ON_FAIL(
        module->findEntryPointByName(vertexEntryPointName, vertexEntryPoint.writeRef()));

    ComPtr<slang::IEntryPoint> fragmentEntryPoint;
    SLANG_RETURN_ON_FAIL(
        module->findEntryPointByName(fragmentEntryPointName, fragmentEntryPoint.writeRef()));

    std::vector<slang::IComponentType*> componentTypes;
    componentTypes.push_back(module);
    componentTypes.push_back(vertexEntryPoint);
    componentTypes.push_back(fragmentEntryPoint);

    return composeLinkAndCreateProgram(
        device,
        slangSession.get(),
        componentTypes,
        outShaderProgram,
        slangReflection);
}

/// Creates a device of `deviceType` configured for unit tests.
///
/// The Slang search paths are the defaults from `getSlangSearchPaths()`
/// followed by `additionalSearchPaths`. Direct SPIR-V emission is always
/// requested; debug information and intermediate dumps are added only when
/// the corresponding preprocessor switches are enabled. When the test context
/// enables debug layers, validation and the context's debug callback are
/// passed to the device as well.
///
/// If device creation fails, SLANG_IGNORE_TEST is invoked (presumably
/// skipping the current test -- see its definition).
Slang::ComPtr<IDevice> createTestingDevice(
    UnitTestContext* context,
    DeviceType deviceType,
    Slang::List<const char*> additionalSearchPaths)
{
    Slang::ComPtr<IDevice> device;

    DeviceDesc deviceDesc = {};
    deviceDesc.deviceType = deviceType;
    deviceDesc.slang.slangGlobalSession = context->slangGlobalSession;

    // `searchPaths` must outlive createDevice(): the descriptor only stores a
    // pointer into its buffer.
    Slang::List<const char*> searchPaths = getSlangSearchPaths();
    searchPaths.addRange(additionalSearchPaths);
    deviceDesc.slang.searchPaths = searchPaths.getBuffer();
    deviceDesc.slang.searchPathCount = searchPaths.getCount();

    std::vector<slang::PreprocessorMacroDesc> preprocessorMacros;
    std::vector<slang::CompilerOptionEntry> compilerOptions;

    slang::CompilerOptionEntry emitSpirvDirectlyEntry = {};
    emitSpirvDirectlyEntry.name = slang::CompilerOptionName::EmitSpirvDirectly;
    emitSpirvDirectlyEntry.value.intValue0 = 1;
    compilerOptions.push_back(emitSpirvDirectlyEntry);

    // GFX_ENABLE_SPIRV_DEBUG is the switch defined at the top of this file;
    // DEBUG_SPIRV is still honoured for builds that define it externally.
#if GFX_ENABLE_SPIRV_DEBUG || DEBUG_SPIRV
    slang::CompilerOptionEntry debugLevelCompilerOptionEntry = {};
    debugLevelCompilerOptionEntry.name = slang::CompilerOptionName::DebugInformation;
    debugLevelCompilerOptionEntry.value.intValue0 = SLANG_DEBUG_INFO_LEVEL_STANDARD;
    compilerOptions.push_back(debugLevelCompilerOptionEntry);
#endif

#if DUMP_INTERMEDIATES
    slang::CompilerOptionEntry dumpIntermediatesOptionEntry = {};
    dumpIntermediatesOptionEntry.name = slang::CompilerOptionName::DumpIntermediates;
    dumpIntermediatesOptionEntry.value.intValue0 = 1;
    compilerOptions.push_back(dumpIntermediatesOptionEntry);
#endif

    deviceDesc.slang.preprocessorMacros = preprocessorMacros.data();
    deviceDesc.slang.preprocessorMacroCount = preprocessorMacros.size();
    deviceDesc.slang.compilerOptionEntries = compilerOptions.data();
    deviceDesc.slang.compilerOptionEntryCount = compilerOptions.size();

    if (context->enableDebugLayers)
    {
        deviceDesc.enableValidation = context->enableDebugLayers;
        deviceDesc.debugCallback = context->debugCallback;
        getRHI()->enableDebugLayers();
    }

    // Declared outside the `if` so it is still alive when createDevice() reads
    // it through `deviceDesc.next`.
    D3D12DeviceExtendedDesc extDesc = {};
    if (deviceType == DeviceType::D3D12)
    {
        extDesc.rootParameterShaderAttributeName = "root";
        deviceDesc.next = &extDesc;
    }

    auto createDeviceResult = getRHI()->createDevice(deviceDesc, device.writeRef());
    if (SLANG_FAILED(createDeviceResult))
    {
        SLANG_IGNORE_TEST
    }
    return device;
}

/// Returns the default module search paths: the empty path plus the
/// gfx-unit-test directory relative to two possible working directories.
Slang::List<const char*> getSlangSearchPaths()
{
    Slang::List<const char*> searchPaths;
    searchPaths.add("");
    searchPaths.add("../../tools/gfx-unit-test");
    searchPaths.add("tools/gfx-unit-test");
    return searchPaths;
}

#if GFX_ENABLE_RENDERDOC_INTEGRATION
RENDERDOC_API_1_1_2* rdoc_api = nullptr;

/// Fetches the RenderDoc in-application API if renderdoc.dll is already loaded
/// into the process; otherwise leaves `rdoc_api` null.
void initializeRenderDoc()
{
    HMODULE mod = GetModuleHandleA("renderdoc.dll");
    if (!mod)
        return;

    auto getApi = reinterpret_cast<pRENDERDOC_GetAPI>(GetProcAddress(mod, "RENDERDOC_GetAPI"));
    if (!getApi)
        return; // Module present but the export is missing: nothing to call.

    int ret = getApi(eRENDERDOC_API_Version_1_1_2, reinterpret_cast<void**>(&rdoc_api));
    assert(ret == 1);
    (void)ret; // Only inspected by the assert; avoids an unused warning under NDEBUG.
}

/// Starts a RenderDoc frame capture, initializing the API on first use.
void renderDocBeginFrame()
{
    if (!rdoc_api)
        initializeRenderDoc();
    if (rdoc_api)
        rdoc_api->StartFrameCapture(nullptr, nullptr);
}

/// Ends the RenderDoc frame capture (if the API is available) and then waits
/// for a character on stdin.
void renderDocEndFrame()
{
    if (rdoc_api)
        rdoc_api->EndFrameCapture(nullptr, nullptr);
    _fgetchar();
}
#else
// RenderDoc integration disabled: the hooks compile to no-ops.
void initializeRenderDoc() {}
void renderDocBeginFrame() {}
void renderDocEndFrame() {}
#endif

} // namespace gfx_test