From 1eda86377847155ed3f0e0b2e40a105af35bd387 Mon Sep 17 00:00:00 2001 From: lucy96chen <47800040+lucy96chen@users.noreply.github.com> Date: Thu, 3 Feb 2022 16:16:33 -0800 Subject: Added Vulkan implementation for createRayTracingPipelineState() (#2109) * preliminary work on createRayTracingPipelineState for Vulkan * more stuff added to createRayTracingPipelineState * Finished filling in all necessary fields for createRayTracingPipelineState() for Vulkan Co-authored-by: Yong He --- tools/gfx/vulkan/render-vk.cpp | 98 +++++++++++++++++++++++++++++++++++++++++- tools/gfx/vulkan/vk-api.h | 1 + 2 files changed, 98 insertions(+), 1 deletion(-) (limited to 'tools') diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 2dcc38a5c..0f84df7d2 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -121,7 +121,9 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( const ComputePipelineStateDesc& desc, IPipelineState** outState) override; - // TODO: Add implementation for createRayTracingPipelineState() - calls VkCreateRayTracingPipelinesKHR + virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( + const RayTracingPipelineStateDesc& desc, + IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outPool) override; @@ -907,6 +909,13 @@ public: pipelineDesc.compute = inDesc; initializeBase(pipelineDesc); } + void init(const RayTracingPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::RayTracing; + pipelineDesc.rayTracing = inDesc; + initializeBase(pipelineDesc); + } BreakableReference m_device; @@ -8634,6 +8643,93 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe return SLANG_OK; } +VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags) +{ + VkPipelineCreateFlags vkFlags = 0; + if (flags & RayTracingPipelineFlags::Enum::SkipTriangles) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR; + if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals) + vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR; + + return vkFlags; +} + +uint32_t findShaderIndexByName(const VkPipelineShaderStageCreateInfo* stageCreateInfos, size_t stageCount, const char* name) +{ + // TODO: Linear search is inefficient, use a Dictionary? + for (size_t i = 0; i < stageCount; ++i) + { + if (strcmp(stageCreateInfos[i].pName, name)) return (uint32_t)i; + } + return VK_SHADER_UNUSED_KHR; +} + +Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState) +{ + auto programImpl = static_cast(desc.program); + if (!programImpl->m_rootObjectLayout->m_pipelineLayout) + { + RefPtr pipelineStateImpl = new PipelineStateImpl(this); + pipelineStateImpl->init(desc); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + pipelineStateImpl->establishStrongDeviceReference(); + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; + } + + VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR }; + raytracingPipelineInfo.pNext = nullptr; + raytracingPipelineInfo.flags = translateFlags(desc.flags); + + VkPipelineShaderStageCreateInfo shaderStageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO }; + raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount(); + raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer(); + + List shaderGroupInfos; + for (int32_t i = 0; i < desc.hitGroupCount; ++i) + { + VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR }; + auto& groupDesc = desc.hitGroups[i]; + + shaderGroupInfo.pNext = nullptr; + shaderGroupInfo.type = (groupDesc.intersectionEntryPoint) + ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR; + shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR; + shaderGroupInfo.closestHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.closestHitEntryPoint); + shaderGroupInfo.anyHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.anyHitEntryPoint); + shaderGroupInfo.intersectionShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.intersectionEntryPoint); + shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr; + + shaderGroupInfos.add(shaderGroupInfo); + } + + raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount(); + raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer(); + + raytracingPipelineInfo.maxPipelineRayRecursionDepth = (uint32_t)desc.maxRecursion; + + raytracingPipelineInfo.pLibraryInfo = nullptr; + raytracingPipelineInfo.pLibraryInterface = nullptr; + + raytracingPipelineInfo.pDynamicState = nullptr; + + raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout; + raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE; + raytracingPipelineInfo.basePipelineIndex = 0; + + VkPipelineCache pipelineCache = VK_NULL_HANDLE; + VkPipeline pipeline = VK_NULL_HANDLE; + SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline)); + + RefPtr pipelineStateImpl = new PipelineStateImpl(this); + pipelineStateImpl->m_pipeline = pipeline; + pipelineStateImpl->init(desc); + pipelineStateImpl->establishStrongDeviceReference(); + m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl); + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; +} + Result VKDevice::createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outPool) diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h index 7df51668f..d2c912161 100644 --- a/tools/gfx/vulkan/vk-api.h +++ b/tools/gfx/vulkan/vk-api.h @@ -162,6 +162,7 @@ namespace gfx { x(vkGetSwapchainImagesKHR) \ x(vkDestroySwapchainKHR) \ x(vkAcquireNextImageKHR) \ + x(vkCreateRayTracingPipelinesKHR) \ /* */ #if SLANG_WINDOWS_FAMILY -- cgit v1.2.3