summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorlucy96chen <47800040+lucy96chen@users.noreply.github.com>2022-02-03 16:16:33 -0800
committerGitHub <noreply@github.com>2022-02-03 16:16:33 -0800
commit1eda86377847155ed3f0e0b2e40a105af35bd387 (patch)
tree76b1788c5841a2bd7dc7b0f15d4e4cf5c4c0a971
parente586610a3752eb9325ff688d6245ab20bb635a81 (diff)
Added Vulkan implementation for createRayTracingPipelineState() (#2109)
* preliminary work on createRayTracingPipelineState for Vulkan * more stuff added to createRayTracingPipelineState * Finished filling in all necessary fields for createRayTracingPipelineState() for Vulkan Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--tools/gfx/vulkan/render-vk.cpp98
-rw-r--r--tools/gfx/vulkan/vk-api.h1
2 files changed, 98 insertions, 1 deletions
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 2dcc38a5c..0f84df7d2 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -121,7 +121,9 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState(
const ComputePipelineStateDesc& desc,
IPipelineState** outState) override;
- // TODO: Add implementation for createRayTracingPipelineState() - calls VkCreateRayTracingPipelinesKHR
+ virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
+ const RayTracingPipelineStateDesc& desc,
+ IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool(
const IQueryPool::Desc& desc,
IQueryPool** outPool) override;
@@ -907,6 +909,13 @@ public:
pipelineDesc.compute = inDesc;
initializeBase(pipelineDesc);
}
+ void init(const RayTracingPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::RayTracing;
+ pipelineDesc.rayTracing = inDesc;
+ initializeBase(pipelineDesc);
+ }
BreakableReference<VKDevice> m_device;
@@ -8634,6 +8643,93 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe
return SLANG_OK;
}
+VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags)
+{
+ VkPipelineCreateFlags vkFlags = 0;
+ if (flags & RayTracingPipelineFlags::Enum::SkipTriangles)
+ vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR;
+ if (flags & RayTracingPipelineFlags::Enum::SkipProcedurals)
+ vkFlags |= VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR;
+
+ return vkFlags;
+}
+
+uint32_t findShaderIndexByName(const VkPipelineShaderStageCreateInfo* stageCreateInfos, size_t stageCount, const char* name)
+{
+ // TODO: Linear search is inefficient, use a Dictionary?
+ for (size_t i = 0; i < stageCount; ++i)
+ {
+ if (strcmp(stageCreateInfos[i].pName, name)) return (uint32_t)i;
+ }
+ return VK_SHADER_UNUSED_KHR;
+}
+
+Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState)
+{
+ auto programImpl = static_cast<ShaderProgramImpl*>(desc.program);
+ if (!programImpl->m_rootObjectLayout->m_pipelineLayout)
+ {
+ RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
+ pipelineStateImpl->init(desc);
+ m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
+ pipelineStateImpl->establishStrongDeviceReference();
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+ }
+
+ VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR };
+ raytracingPipelineInfo.pNext = nullptr;
+ raytracingPipelineInfo.flags = translateFlags(desc.flags);
+
+ VkPipelineShaderStageCreateInfo shaderStageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO };
+ raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount();
+ raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
+
+ List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos;
+ for (int32_t i = 0; i < desc.hitGroupCount; ++i)
+ {
+ VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
+ auto& groupDesc = desc.hitGroups[i];
+
+ shaderGroupInfo.pNext = nullptr;
+ shaderGroupInfo.type = (groupDesc.intersectionEntryPoint)
+ ? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
+ shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.closestHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.closestHitEntryPoint);
+ shaderGroupInfo.anyHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.anyHitEntryPoint);
+ shaderGroupInfo.intersectionShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.intersectionEntryPoint);
+ shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
+
+ shaderGroupInfos.add(shaderGroupInfo);
+ }
+
+ raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount();
+ raytracingPipelineInfo.pGroups = shaderGroupInfos.getBuffer();
+
+ raytracingPipelineInfo.maxPipelineRayRecursionDepth = (uint32_t)desc.maxRecursion;
+
+ raytracingPipelineInfo.pLibraryInfo = nullptr;
+ raytracingPipelineInfo.pLibraryInterface = nullptr;
+
+ raytracingPipelineInfo.pDynamicState = nullptr;
+
+ raytracingPipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
+ raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
+ raytracingPipelineInfo.basePipelineIndex = 0;
+
+ VkPipelineCache pipelineCache = VK_NULL_HANDLE;
+ VkPipeline pipeline = VK_NULL_HANDLE;
+ SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline));
+
+ RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
+ pipelineStateImpl->m_pipeline = pipeline;
+ pipelineStateImpl->init(desc);
+ pipelineStateImpl->establishStrongDeviceReference();
+ m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+}
+
Result VKDevice::createQueryPool(
const IQueryPool::Desc& desc,
IQueryPool** outPool)
diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h
index 7df51668f..d2c912161 100644
--- a/tools/gfx/vulkan/vk-api.h
+++ b/tools/gfx/vulkan/vk-api.h
@@ -162,6 +162,7 @@ namespace gfx {
x(vkGetSwapchainImagesKHR) \
x(vkDestroySwapchainKHR) \
x(vkAcquireNextImageKHR) \
+ x(vkCreateRayTracingPipelinesKHR) \
/* */
#if SLANG_WINDOWS_FAMILY