summaryrefslogtreecommitdiffstats
path: root/tools/gfx/vulkan
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2024-02-06 19:31:28 -0600
committerGitHub <noreply@github.com>2024-02-06 17:31:28 -0800
commit3358b3dd4680cc3f86bbd22b84c242c7f0053775 (patch)
treeebf47d55ea30e7562f91af5662ae166c78a11541 /tools/gfx/vulkan
parentf359df9202140ee575d0c2efd3bb10880edcc331 (diff)
gfx:Add callback to IPipelineCreationAPIDispatcher (#3556)
* gfx:Add callback to IPipelineCreationAPIDispatcher Add the callback to IPipelineCreationAPIDispatcher in Vulkan backend in slang-gfx lib. * gfx:add uuid for vulkan pipeline dispatcher Add a define of SLANG_UUID_IVulkanPipelineCreationAPIDispatcher for Vulkan specific IPipelineCreationAPIDispatcher such that libgfx.so can have special handle to Vulkan pipeline dispatcher without break binary compatibility. In the RendererBase::initialize call, we will provide this new UUID when the DeviceType is Vulkan. * gfx: add new variable to GfxGUID Add new variable to GfxGUID IID_IVulkanPipelineCreationAPIDispatcher with initialization of SLANG_UUID_IVulkanPipelineCreationAPIDispatcher to make the implementation aligned with existing GfxGUID::IID_IPipelineCreationAPIDispatcher. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tools/gfx/vulkan')
-rw-r--r--tools/gfx/vulkan/vk-pipeline-state.cpp50
1 files changed, 43 insertions, 7 deletions
diff --git a/tools/gfx/vulkan/vk-pipeline-state.cpp b/tools/gfx/vulkan/vk-pipeline-state.cpp
index 712811f9a..d861a343c 100644
--- a/tools/gfx/vulkan/vk-pipeline-state.cpp
+++ b/tools/gfx/vulkan/vk-pipeline-state.cpp
@@ -273,8 +273,20 @@ Result PipelineStateImpl::createVKGraphicsPipelineState()
pipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
pipelineInfo.pDynamicState = &dynamicStateInfo;
- SLANG_VK_CHECK(m_device->m_api.vkCreateGraphicsPipelines(
- m_device->m_device, pipelineCache, 1, &pipelineInfo, nullptr, &m_pipeline));
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(
+ m_device->m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
+ m_device,
+ programImpl->linkedProgram.get(),
+ &pipelineInfo,
+ (void**)&m_pipeline));
+ }
+ else
+ {
+ SLANG_VK_CHECK(m_device->m_api.vkCreateGraphicsPipelines(
+ m_device->m_device, pipelineCache, 1, &pipelineInfo, nullptr, &m_pipeline));
+ }
return SLANG_OK;
}
@@ -287,14 +299,26 @@ Result PipelineStateImpl::createVKComputePipelineState()
SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device));
}
- VkPipelineCache pipelineCache = VK_NULL_HANDLE;
-
VkComputePipelineCreateInfo computePipelineInfo = {
- VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
+ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO};
computePipelineInfo.stage = programImpl->m_stageCreateInfos[0];
computePipelineInfo.layout = programImpl->m_rootObjectLayout->m_pipelineLayout;
- SLANG_VK_CHECK(m_device->m_api.vkCreateComputePipelines(
- m_device->m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &m_pipeline));
+
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ SLANG_RETURN_ON_FAIL(
+ m_device->m_pipelineCreationAPIDispatcher->createComputePipelineState(
+ m_device,
+ programImpl->linkedProgram.get(),
+ &computePipelineInfo,
+ (void**)&m_pipeline));
+ }
+ else
+ {
+ VkPipelineCache pipelineCache = VK_NULL_HANDLE;
+ SLANG_VK_CHECK(m_device->m_api.vkCreateComputePipelines(
+ m_device->m_device, pipelineCache, 1, &computePipelineInfo, nullptr, &m_pipeline));
+ }
return SLANG_OK;
}
@@ -424,6 +448,12 @@ Result RayTracingPipelineStateImpl::createVKRayTracingPipelineState()
raytracingPipelineInfo.basePipelineHandle = VK_NULL_HANDLE;
raytracingPipelineInfo.basePipelineIndex = 0;
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ m_device->m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(
+ m_device, programImpl->linkedProgram.get());
+ }
+
VkPipelineCache pipelineCache = VK_NULL_HANDLE;
SLANG_VK_CHECK(m_device->m_api.vkCreateRayTracingPipelinesKHR(
m_device->m_device,
@@ -434,6 +464,12 @@ Result RayTracingPipelineStateImpl::createVKRayTracingPipelineState()
nullptr,
&m_pipeline));
shaderGroupCount = shaderGroupInfos.getCount();
+
+ if (m_device->m_pipelineCreationAPIDispatcher)
+ {
+ m_device->m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(
+ m_device, programImpl->linkedProgram.get());
+ }
return SLANG_OK;
}
Result RayTracingPipelineStateImpl::ensureAPIPipelineStateCreated()