diff options
| author | lucy96chen <47800040+lucy96chen@users.noreply.github.com> | 2022-02-03 16:16:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-02-03 16:16:33 -0800 |
| commit | 1eda86377847155ed3f0e0b2e40a105af35bd387 (patch) | |
| tree | 76b1788c5841a2bd7dc7b0f15d4e4cf5c4c0a971 | |
| parent | e586610a3752eb9325ff688d6245ab20bb635a81 (diff) | |
Added Vulkan implementation for createRayTracingPipelineState() (#2109)
* preliminary work on createRayTracingPipelineState for Vulkan
* more stuff added to createRayTracingPipelineState
* Finished filling in all necessary fields for createRayTracingPipelineState() for Vulkan
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 98 | ||||
| -rw-r--r-- | tools/gfx/vulkan/vk-api.h | 1 |
2 files changed, 98 insertions, 1 deletions
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 2dcc38a5c..0f84df7d2 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -121,7 +121,9 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( const ComputePipelineStateDesc& desc, IPipelineState** outState) override; - // TODO: Add implementation for createRayTracingPipelineState() - calls VkCreateRayTracingPipelinesKHR + virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( + const RayTracingPipelineStateDesc& desc, + IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outPool) override; @@ -907,6 +909,13 @@ public: pipelineDesc.compute = inDesc; initializeBase(pipelineDesc); } + void init(const RayTracingPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::RayTracing; + pipelineDesc.rayTracing = inDesc; + initializeBase(pipelineDesc); + } BreakableReference<VKDevice> m_device; @@ -8634,6 +8643,93 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe return SLANG_OK; } +VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags) +{ + VkPipelineCreateFlags vkFlags = 0; + if (flags & RayTracingPipelineFlags::Enum::SkipTriangles) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR; + if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR; + + return vkFlags; +} + +uint32_t findShaderIndexByName(const VkPipelineShaderStageCreateInfo* stageCreateInfos, size_t stageCount, const char* name) +{ + // TODO: Linear search is inefficient, use a Dictionary? + for (size_t i = 0; i < stageCount; ++i) + { + if (strcmp(stageCreateInfos[i].pName, name)) return (uint32_t)i; + } + return VK_SHADER_UNUSED_KHR; +} + +Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState) +{ + auto programImpl = static_cast<ShaderProgramImpl*>(desc.program); + if (!programImpl->m_rootObjectLayout->m_pipelineLayout) + { + RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); + pipelineStateImpl->init(desc); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + pipelineStateImpl->establishStrongDeviceReference(); + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; + } + + VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR }; + raytracingPipelineInfo.pNext = nullptr; + raytracingPipelineInfo.flags = translateFlags(desc.flags); + + VkPipelineShaderStageCreateInfo shaderStageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO }; + raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); + raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); + + List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos; + for (int32_t i = 0; i < desc.hitGroupCount; ++i) + { + VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; + auto& groupDesc = desc.hitGroups[i]; + + shaderGroupInfo.pNext = nullptr; + shaderGroupInfo.type = (groupDesc.intersectionEntryPoint) + ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; + shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.closestHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.closestHitEntryPoint); + shaderGroupInfo.anyHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.anyHitEntryPoint); + shaderGroupInfo.intersectionShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.intersectionEntryPoint); + shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + + shaderGroupInfos.add(shaderGroupInfo); + } + + raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount(); + raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer(); + + raytracingPipelineInfo.maxPipelineRayRecursionDepth = (uint32_t)desc.maxRecursion; + + raytracingPipelineInfo.pLibraryInfo = nullptr; + raytracingPipelineInfo.pLibraryInterface = nullptr; + + raytracingPipelineInfo.pDynamicState = nullptr; + + raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; + raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE; + raytracingPipelineInfo.basePipelineIndex = 0; + + VkPipelineCache pipelineCache = VK_NULL_HANDLE; + VkPipeline pipeline = VK_NULL_HANDLE; + SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline)); + + RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this); + pipelineStateImpl->m_pipeline = pipeline; + pipelineStateImpl->init(desc); + pipelineStateImpl->establishStrongDeviceReference(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; +} + Result VKDevice::createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outPool) diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h index 7df51668f..d2c912161 100644 --- a/tools/gfx/vulkan/vk-api.h +++ b/tools/gfx/vulkan/vk-api.h @@ -162,6 +162,7 @@ namespace gfx { x(vkGetSwapchainImagesKHR) \ x(vkDestroySwapchainKHR) \ x(vkAcquireNextImageKHR) \ + x(vkCreateRayTracingPipelinesKHR) \ /* */ #if SLANG_WINDOWS_FAMILY |
